1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// Package pstest provides a fake Cloud PubSub service for testing. It implements a 16// simplified form of the service, suitable for unit tests. It may behave 17// differently from the actual service in ways in which the service is 18// non-deterministic or unspecified: timing, delivery order, etc. 19// 20// This package is EXPERIMENTAL and is subject to change without notice. 21// 22// See the example for usage. 23package pstest 24 25import ( 26 "context" 27 "fmt" 28 "io" 29 "path" 30 "sort" 31 "strings" 32 "sync" 33 "sync/atomic" 34 "time" 35 36 "cloud.google.com/go/internal/testutil" 37 "github.com/golang/protobuf/ptypes" 38 durpb "github.com/golang/protobuf/ptypes/duration" 39 emptypb "github.com/golang/protobuf/ptypes/empty" 40 pb "google.golang.org/genproto/googleapis/pubsub/v1" 41 "google.golang.org/grpc/codes" 42 "google.golang.org/grpc/status" 43) 44 45// For testing. Note that even though changes to the now variable are atomic, a call 46// to the stored function can race with a change to that function. This could be a 47// problem if tests are run in parallel, or even if concurrent parts of the same test 48// change the value of the variable. 49var now atomic.Value 50 51func init() { 52 now.Store(time.Now) 53 ResetMinAckDeadline() 54} 55 56func timeNow() time.Time { 57 return now.Load().(func() time.Time)() 58} 59 60// Server is a fake Pub/Sub server. 61type Server struct { 62 srv *testutil.Server 63 Addr string // The address that the server is listening on. 64 GServer GServer // Not intended to be used directly. 65} 66 67// GServer is the underlying service implementor. It is not intended to be used 68// directly. 69type GServer struct { 70 pb.PublisherServer 71 pb.SubscriberServer 72 73 mu sync.Mutex 74 topics map[string]*topic 75 subs map[string]*subscription 76 msgs []*Message // all messages ever published 77 msgsByID map[string]*Message 78 wg sync.WaitGroup 79 nextID int 80 streamTimeout time.Duration 81} 82 83// NewServer creates a new fake server running in the current process. 84func NewServer() *Server { 85 srv, err := testutil.NewServer() 86 if err != nil { 87 panic(fmt.Sprintf("pstest.NewServer: %v", err)) 88 } 89 s := &Server{ 90 srv: srv, 91 Addr: srv.Addr, 92 GServer: GServer{ 93 topics: map[string]*topic{}, 94 subs: map[string]*subscription{}, 95 msgsByID: map[string]*Message{}, 96 }, 97 } 98 pb.RegisterPublisherServer(srv.Gsrv, &s.GServer) 99 pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer) 100 srv.Start() 101 return s 102} 103 104// Publish behaves as if the Publish RPC was called with a message with the given 105// data and attrs. It returns the ID of the message. 106// The topic will be created if it doesn't exist. 107// 108// Publish panics if there is an error, which is appropriate for testing. 109func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string { 110 const topicPattern = "projects/*/topics/*" 111 ok, err := path.Match(topicPattern, topic) 112 if err != nil { 113 panic(err) 114 } 115 if !ok { 116 panic(fmt.Sprintf("topic name must be of the form %q", topicPattern)) 117 } 118 _, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic}) 119 req := &pb.PublishRequest{ 120 Topic: topic, 121 Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}}, 122 } 123 res, err := s.GServer.Publish(context.TODO(), req) 124 if err != nil { 125 panic(fmt.Sprintf("pstest.Server.Publish: %v", err)) 126 } 127 return res.MessageIds[0] 128} 129 130// SetStreamTimeout sets the amount of time a stream will be active before it shuts 131// itself down. This mimics the real service's behavior of closing streams after 30 132// minutes. If SetStreamTimeout is never called or is passed zero, streams never shut 133// down. 134func (s *Server) SetStreamTimeout(d time.Duration) { 135 s.GServer.mu.Lock() 136 defer s.GServer.mu.Unlock() 137 s.GServer.streamTimeout = d 138} 139 140// A Message is a message that was published to the server. 141type Message struct { 142 ID string 143 Data []byte 144 Attributes map[string]string 145 PublishTime time.Time 146 Deliveries int // number of times delivery of the message was attempted 147 Acks int // number of acks received from clients 148 149 // protected by server mutex 150 deliveries int 151 acks int 152 Modacks []Modack // modacks received by server for this message 153 154} 155 156// Modack represents a modack sent to the server. 157type Modack struct { 158 AckID string 159 AckDeadline int32 160 ReceivedAt time.Time 161} 162 163// Messages returns information about all messages ever published. 164func (s *Server) Messages() []*Message { 165 s.GServer.mu.Lock() 166 defer s.GServer.mu.Unlock() 167 168 var msgs []*Message 169 for _, m := range s.GServer.msgs { 170 m.Deliveries = m.deliveries 171 m.Acks = m.acks 172 msgs = append(msgs, m) 173 } 174 return msgs 175} 176 177// Message returns the message with the given ID, or nil if no message 178// with that ID was published. 179func (s *Server) Message(id string) *Message { 180 s.GServer.mu.Lock() 181 defer s.GServer.mu.Unlock() 182 183 m := s.GServer.msgsByID[id] 184 if m != nil { 185 m.Deliveries = m.deliveries 186 m.Acks = m.acks 187 } 188 return m 189} 190 191// Wait blocks until all server activity has completed. 192func (s *Server) Wait() { 193 s.GServer.wg.Wait() 194} 195 196// Close shuts down the server and releases all resources. 197func (s *Server) Close() error { 198 s.srv.Close() 199 s.GServer.mu.Lock() 200 defer s.GServer.mu.Unlock() 201 for _, sub := range s.GServer.subs { 202 sub.stop() 203 } 204 return nil 205} 206 207func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) { 208 s.mu.Lock() 209 defer s.mu.Unlock() 210 211 if s.topics[t.Name] != nil { 212 return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name) 213 } 214 top := newTopic(t) 215 s.topics[t.Name] = top 216 return top.proto, nil 217} 218 219func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) { 220 s.mu.Lock() 221 defer s.mu.Unlock() 222 223 if t := s.topics[req.Topic]; t != nil { 224 return t.proto, nil 225 } 226 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 227} 228 229func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) { 230 s.mu.Lock() 231 defer s.mu.Unlock() 232 233 t := s.topics[req.Topic.Name] 234 if t == nil { 235 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name) 236 } 237 for _, path := range req.UpdateMask.Paths { 238 switch path { 239 case "labels": 240 t.proto.Labels = req.Topic.Labels 241 case "message_storage_policy": // "fetch" the policy 242 t.proto.MessageStoragePolicy = &pb.MessageStoragePolicy{AllowedPersistenceRegions: []string{"US"}} 243 default: 244 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) 245 } 246 } 247 return t.proto, nil 248} 249 250func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) { 251 s.mu.Lock() 252 defer s.mu.Unlock() 253 254 var names []string 255 for n := range s.topics { 256 if strings.HasPrefix(n, req.Project) { 257 names = append(names, n) 258 } 259 } 260 sort.Strings(names) 261 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 262 if err != nil { 263 return nil, err 264 } 265 res := &pb.ListTopicsResponse{NextPageToken: nextToken} 266 for i := from; i < to; i++ { 267 res.Topics = append(res.Topics, s.topics[names[i]].proto) 268 } 269 return res, nil 270} 271 272func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) { 273 s.mu.Lock() 274 defer s.mu.Unlock() 275 276 var names []string 277 for name, sub := range s.subs { 278 if sub.topic.proto.Name == req.Topic { 279 names = append(names, name) 280 } 281 } 282 sort.Strings(names) 283 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 284 if err != nil { 285 return nil, err 286 } 287 return &pb.ListTopicSubscriptionsResponse{ 288 Subscriptions: names[from:to], 289 NextPageToken: nextToken, 290 }, nil 291} 292 293func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) { 294 s.mu.Lock() 295 defer s.mu.Unlock() 296 297 t := s.topics[req.Topic] 298 if t == nil { 299 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 300 } 301 t.stop() 302 delete(s.topics, req.Topic) 303 return &emptypb.Empty{}, nil 304} 305 306func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) { 307 s.mu.Lock() 308 defer s.mu.Unlock() 309 310 if ps.Name == "" { 311 return nil, status.Errorf(codes.InvalidArgument, "missing name") 312 } 313 if s.subs[ps.Name] != nil { 314 return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name) 315 } 316 if ps.Topic == "" { 317 return nil, status.Errorf(codes.InvalidArgument, "missing topic") 318 } 319 top := s.topics[ps.Topic] 320 if top == nil { 321 return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic) 322 } 323 if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil { 324 return nil, err 325 } 326 if ps.MessageRetentionDuration == nil { 327 ps.MessageRetentionDuration = defaultMessageRetentionDuration 328 } 329 if err := checkMRD(ps.MessageRetentionDuration); err != nil { 330 return nil, err 331 } 332 if ps.PushConfig == nil { 333 ps.PushConfig = &pb.PushConfig{} 334 } 335 336 sub := newSubscription(top, &s.mu, ps) 337 top.subs[ps.Name] = sub 338 s.subs[ps.Name] = sub 339 sub.start(&s.wg) 340 return ps, nil 341} 342 343// Can be set for testing. 344var minAckDeadlineSecs int32 345 346// SetMinAckDeadline changes the minack deadline to n. Must be 347// greater than or equal to 1 second. Remember to reset this value 348// to the default after your test changes it. Example usage: 349// pstest.SetMinAckDeadlineSecs(1) 350// defer pstest.ResetMinAckDeadlineSecs() 351func SetMinAckDeadline(n time.Duration) { 352 if n < time.Second { 353 panic("SetMinAckDeadline expects a value greater than 1 second") 354 } 355 356 minAckDeadlineSecs = int32(n / time.Second) 357} 358 359// ResetMinAckDeadline resets the minack deadline to the default. 360func ResetMinAckDeadline() { 361 minAckDeadlineSecs = 10 362} 363 364func checkAckDeadline(ads int32) error { 365 if ads < minAckDeadlineSecs || ads > 600 { 366 // PubSub service returns Unknown. 367 return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads) 368 } 369 return nil 370} 371 372const ( 373 minMessageRetentionDuration = 10 * time.Minute 374 maxMessageRetentionDuration = 168 * time.Hour 375) 376 377var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration) 378 379func checkMRD(pmrd *durpb.Duration) error { 380 mrd, err := ptypes.Duration(pmrd) 381 if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration { 382 return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd) 383 } 384 return nil 385} 386 387func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) { 388 s.mu.Lock() 389 defer s.mu.Unlock() 390 sub, err := s.findSubscription(req.Subscription) 391 if err != nil { 392 return nil, err 393 } 394 return sub.proto, nil 395} 396 397func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) { 398 if req.Subscription == nil { 399 return nil, status.Errorf(codes.InvalidArgument, "missing subscription") 400 } 401 s.mu.Lock() 402 defer s.mu.Unlock() 403 sub, err := s.findSubscription(req.Subscription.Name) 404 if err != nil { 405 return nil, err 406 } 407 for _, path := range req.UpdateMask.Paths { 408 switch path { 409 case "push_config": 410 sub.proto.PushConfig = req.Subscription.PushConfig 411 412 case "ack_deadline_seconds": 413 a := req.Subscription.AckDeadlineSeconds 414 if err := checkAckDeadline(a); err != nil { 415 return nil, err 416 } 417 sub.proto.AckDeadlineSeconds = a 418 419 case "retain_acked_messages": 420 sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages 421 422 case "message_retention_duration": 423 if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil { 424 return nil, err 425 } 426 sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration 427 428 case "labels": 429 sub.proto.Labels = req.Subscription.Labels 430 431 case "expiration_policy": 432 sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy 433 434 default: 435 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) 436 } 437 } 438 return sub.proto, nil 439} 440 441func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) { 442 s.mu.Lock() 443 defer s.mu.Unlock() 444 445 var names []string 446 for name := range s.subs { 447 if strings.HasPrefix(name, req.Project) { 448 names = append(names, name) 449 } 450 } 451 sort.Strings(names) 452 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 453 if err != nil { 454 return nil, err 455 } 456 res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken} 457 for i := from; i < to; i++ { 458 res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto) 459 } 460 return res, nil 461} 462 463func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) { 464 s.mu.Lock() 465 defer s.mu.Unlock() 466 sub, err := s.findSubscription(req.Subscription) 467 if err != nil { 468 return nil, err 469 } 470 sub.stop() 471 delete(s.subs, req.Subscription) 472 sub.topic.deleteSub(sub) 473 return &emptypb.Empty{}, nil 474} 475 476func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) { 477 s.mu.Lock() 478 defer s.mu.Unlock() 479 480 if req.Topic == "" { 481 return nil, status.Errorf(codes.InvalidArgument, "missing topic") 482 } 483 top := s.topics[req.Topic] 484 if top == nil { 485 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 486 } 487 var ids []string 488 for _, pm := range req.Messages { 489 id := fmt.Sprintf("m%d", s.nextID) 490 s.nextID++ 491 pm.MessageId = id 492 pubTime := timeNow() 493 tsPubTime, err := ptypes.TimestampProto(pubTime) 494 if err != nil { 495 return nil, status.Errorf(codes.Internal, err.Error()) 496 } 497 pm.PublishTime = tsPubTime 498 m := &Message{ 499 ID: id, 500 Data: pm.Data, 501 Attributes: pm.Attributes, 502 PublishTime: pubTime, 503 } 504 top.publish(pm, m) 505 ids = append(ids, id) 506 s.msgs = append(s.msgs, m) 507 s.msgsByID[id] = m 508 } 509 return &pb.PublishResponse{MessageIds: ids}, nil 510} 511 512type topic struct { 513 proto *pb.Topic 514 subs map[string]*subscription 515} 516 517func newTopic(pt *pb.Topic) *topic { 518 return &topic{ 519 proto: pt, 520 subs: map[string]*subscription{}, 521 } 522} 523 524func (t *topic) stop() { 525 for _, sub := range t.subs { 526 sub.proto.Topic = "_deleted-topic_" 527 sub.stop() 528 } 529} 530 531func (t *topic) deleteSub(sub *subscription) { 532 delete(t.subs, sub.proto.Name) 533} 534 535func (t *topic) publish(pm *pb.PubsubMessage, m *Message) { 536 for _, s := range t.subs { 537 s.msgs[pm.MessageId] = &message{ 538 publishTime: m.PublishTime, 539 proto: &pb.ReceivedMessage{ 540 AckId: pm.MessageId, 541 Message: pm, 542 }, 543 deliveries: &m.deliveries, 544 acks: &m.acks, 545 streamIndex: -1, 546 } 547 } 548} 549 550type subscription struct { 551 topic *topic 552 mu *sync.Mutex // the server mutex, here for convenience 553 proto *pb.Subscription 554 ackTimeout time.Duration 555 msgs map[string]*message // unacked messages by message ID 556 streams []*stream 557 done chan struct{} 558} 559 560func newSubscription(t *topic, mu *sync.Mutex, ps *pb.Subscription) *subscription { 561 at := time.Duration(ps.AckDeadlineSeconds) * time.Second 562 if at == 0 { 563 at = 10 * time.Second 564 } 565 return &subscription{ 566 topic: t, 567 mu: mu, 568 proto: ps, 569 ackTimeout: at, 570 msgs: map[string]*message{}, 571 done: make(chan struct{}), 572 } 573} 574 575func (s *subscription) start(wg *sync.WaitGroup) { 576 wg.Add(1) 577 go func() { 578 defer wg.Done() 579 for { 580 select { 581 case <-s.done: 582 return 583 case <-time.After(10 * time.Millisecond): 584 s.deliver() 585 } 586 } 587 }() 588} 589 590func (s *subscription) stop() { 591 close(s.done) 592} 593 594func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) { 595 s.mu.Lock() 596 defer s.mu.Unlock() 597 598 sub, err := s.findSubscription(req.Subscription) 599 if err != nil { 600 return nil, err 601 } 602 for _, id := range req.AckIds { 603 sub.ack(id) 604 } 605 return &emptypb.Empty{}, nil 606} 607 608func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) { 609 s.mu.Lock() 610 defer s.mu.Unlock() 611 sub, err := s.findSubscription(req.Subscription) 612 if err != nil { 613 return nil, err 614 } 615 now := time.Now() 616 for _, id := range req.AckIds { 617 s.msgsByID[id].Modacks = append(s.msgsByID[id].Modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now}) 618 } 619 dur := secsToDur(req.AckDeadlineSeconds) 620 for _, id := range req.AckIds { 621 sub.modifyAckDeadline(id, dur) 622 } 623 return &emptypb.Empty{}, nil 624} 625 626func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) { 627 s.mu.Lock() 628 sub, err := s.findSubscription(req.Subscription) 629 if err != nil { 630 s.mu.Unlock() 631 return nil, err 632 } 633 max := int(req.MaxMessages) 634 if max < 0 { 635 s.mu.Unlock() 636 return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative") 637 } 638 if max == 0 { // MaxMessages not specified; use a default. 639 max = 1000 640 } 641 msgs := sub.pull(max) 642 s.mu.Unlock() 643 // Implement the spec from the pubsub proto: 644 // "If ReturnImmediately set to true, the system will respond immediately even if 645 // it there are no messages available to return in the `Pull` response. 646 // Otherwise, the system may wait (for a bounded amount of time) until at 647 // least one message is available, rather than returning no messages." 648 if len(msgs) == 0 && !req.ReturnImmediately { 649 // Wait for a short amount of time for a message. 650 // TODO: signal when a message arrives, so we don't wait the whole time. 651 select { 652 case <-ctx.Done(): 653 return nil, ctx.Err() 654 case <-time.After(500 * time.Millisecond): 655 s.mu.Lock() 656 msgs = sub.pull(max) 657 s.mu.Unlock() 658 } 659 } 660 return &pb.PullResponse{ReceivedMessages: msgs}, nil 661} 662 663func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error { 664 // Receive initial message configuring the pull. 665 req, err := sps.Recv() 666 if err != nil { 667 return err 668 } 669 s.mu.Lock() 670 sub, err := s.findSubscription(req.Subscription) 671 s.mu.Unlock() 672 if err != nil { 673 return err 674 } 675 // Create a new stream to handle the pull. 676 st := sub.newStream(sps, s.streamTimeout) 677 err = st.pull(&s.wg) 678 sub.deleteStream(st) 679 return err 680} 681 682func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) { 683 // Only handle time-based seeking for now. 684 // This fake doesn't deal with snapshots. 685 var target time.Time 686 switch v := req.Target.(type) { 687 case nil: 688 return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type") 689 case *pb.SeekRequest_Time: 690 var err error 691 target, err = ptypes.Timestamp(v.Time) 692 if err != nil { 693 return nil, status.Errorf(codes.InvalidArgument, "bad Time target: %v", err) 694 } 695 default: 696 return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v) 697 } 698 699 // The entire server must be locked while doing the work below, 700 // because the messages don't have any other synchronization. 701 s.mu.Lock() 702 defer s.mu.Unlock() 703 sub, err := s.findSubscription(req.Subscription) 704 if err != nil { 705 return nil, err 706 } 707 // Drop all messages from sub that were published before the target time. 708 for id, m := range sub.msgs { 709 if m.publishTime.Before(target) { 710 delete(sub.msgs, id) 711 (*m.acks)++ 712 } 713 } 714 // Un-ack any already-acked messages after this time; 715 // redelivering them to the subscription is the closest analogue here. 716 for _, m := range s.msgs { 717 if m.PublishTime.Before(target) { 718 continue 719 } 720 sub.msgs[m.ID] = &message{ 721 publishTime: m.PublishTime, 722 proto: &pb.ReceivedMessage{ 723 AckId: m.ID, 724 // This was not preserved! 725 //Message: pm, 726 }, 727 deliveries: &m.deliveries, 728 acks: &m.acks, 729 streamIndex: -1, 730 } 731 } 732 return &pb.SeekResponse{}, nil 733} 734 735// Gets a subscription that must exist. 736// Must be called with the lock held. 737func (s *GServer) findSubscription(name string) (*subscription, error) { 738 if name == "" { 739 return nil, status.Errorf(codes.InvalidArgument, "missing subscription") 740 } 741 sub := s.subs[name] 742 if sub == nil { 743 return nil, status.Errorf(codes.NotFound, "subscription %s", name) 744 } 745 return sub, nil 746} 747 748// Must be called with the lock held. 749func (s *subscription) pull(max int) []*pb.ReceivedMessage { 750 now := timeNow() 751 s.maintainMessages(now) 752 var msgs []*pb.ReceivedMessage 753 for _, m := range s.msgs { 754 if m.outstanding() { 755 continue 756 } 757 (*m.deliveries)++ 758 m.ackDeadline = now.Add(s.ackTimeout) 759 msgs = append(msgs, m.proto) 760 if len(msgs) >= max { 761 break 762 } 763 } 764 return msgs 765} 766 767func (s *subscription) deliver() { 768 s.mu.Lock() 769 defer s.mu.Unlock() 770 771 now := timeNow() 772 s.maintainMessages(now) 773 // Try to deliver each remaining message. 774 curIndex := 0 775 for _, m := range s.msgs { 776 if m.outstanding() { 777 continue 778 } 779 // If the message was never delivered before, start with the stream at 780 // curIndex. If it was delivered before, start with the stream after the one 781 // that owned it. 782 if m.streamIndex < 0 { 783 delIndex, ok := s.tryDeliverMessage(m, curIndex, now) 784 if !ok { 785 break 786 } 787 curIndex = delIndex + 1 788 m.streamIndex = curIndex 789 } else { 790 delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now) 791 if !ok { 792 break 793 } 794 m.streamIndex = delIndex 795 } 796 } 797} 798 799// tryDeliverMessage attempts to deliver m to the stream at index i. If it can't, it 800// tries streams i+1, i+2, ..., wrapping around. Once it's tried all streams, it 801// exits. 802// 803// It returns the index of the stream it delivered the message to, or 0, false if 804// it didn't deliver the message. 805// 806// Must be called with the lock held. 807func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) { 808 for i := 0; i < len(s.streams); i++ { 809 idx := (i + start) % len(s.streams) 810 811 st := s.streams[idx] 812 select { 813 case <-st.done: 814 s.streams = deleteStreamAt(s.streams, idx) 815 i-- 816 817 case st.msgc <- m.proto: 818 (*m.deliveries)++ 819 m.ackDeadline = now.Add(st.ackTimeout) 820 return idx, true 821 822 default: 823 } 824 } 825 return 0, false 826} 827 828var retentionDuration = 10 * time.Minute 829 830// Must be called with the lock held. 831func (s *subscription) maintainMessages(now time.Time) { 832 for id, m := range s.msgs { 833 // Mark a message as re-deliverable if its ack deadline has expired. 834 if m.outstanding() && now.After(m.ackDeadline) { 835 m.makeAvailable() 836 } 837 pubTime, err := ptypes.Timestamp(m.proto.Message.PublishTime) 838 if err != nil { 839 panic(err) 840 } 841 // Remove messages that have been undelivered for a long time. 842 if !m.outstanding() && now.Sub(pubTime) > retentionDuration { 843 delete(s.msgs, id) 844 } 845 } 846} 847 848func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream { 849 st := &stream{ 850 sub: s, 851 done: make(chan struct{}), 852 msgc: make(chan *pb.ReceivedMessage), 853 gstream: gs, 854 ackTimeout: s.ackTimeout, 855 timeout: timeout, 856 } 857 s.mu.Lock() 858 s.streams = append(s.streams, st) 859 s.mu.Unlock() 860 return st 861} 862 863func (s *subscription) deleteStream(st *stream) { 864 s.mu.Lock() 865 defer s.mu.Unlock() 866 var i int 867 for i = 0; i < len(s.streams); i++ { 868 if s.streams[i] == st { 869 break 870 } 871 } 872 if i < len(s.streams) { 873 s.streams = deleteStreamAt(s.streams, i) 874 } 875} 876func deleteStreamAt(s []*stream, i int) []*stream { 877 // Preserve order for round-robin delivery. 878 return append(s[:i], s[i+1:]...) 879} 880 881type message struct { 882 proto *pb.ReceivedMessage 883 publishTime time.Time 884 ackDeadline time.Time 885 deliveries *int 886 acks *int 887 streamIndex int // index of stream that currently owns msg, for round-robin delivery 888} 889 890// A message is outstanding if it is owned by some stream. 891func (m *message) outstanding() bool { 892 return !m.ackDeadline.IsZero() 893} 894 895func (m *message) makeAvailable() { 896 m.ackDeadline = time.Time{} 897} 898 899type stream struct { 900 sub *subscription 901 done chan struct{} // closed when the stream is finished 902 msgc chan *pb.ReceivedMessage 903 gstream pb.Subscriber_StreamingPullServer 904 ackTimeout time.Duration 905 timeout time.Duration 906} 907 908// pull manages the StreamingPull interaction for the life of the stream. 909func (st *stream) pull(wg *sync.WaitGroup) error { 910 errc := make(chan error, 2) 911 wg.Add(2) 912 go func() { 913 defer wg.Done() 914 errc <- st.sendLoop() 915 }() 916 go func() { 917 defer wg.Done() 918 errc <- st.recvLoop() 919 }() 920 var tchan <-chan time.Time 921 if st.timeout > 0 { 922 tchan = time.After(st.timeout) 923 } 924 // Wait until one of the goroutines returns an error, or we time out. 925 var err error 926 select { 927 case err = <-errc: 928 if err == io.EOF { 929 err = nil 930 } 931 case <-tchan: 932 } 933 close(st.done) // stop the other goroutine 934 return err 935} 936 937func (st *stream) sendLoop() error { 938 for { 939 select { 940 case <-st.done: 941 return nil 942 case rm := <-st.msgc: 943 res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}} 944 if err := st.gstream.Send(res); err != nil { 945 return err 946 } 947 } 948 } 949} 950 951func (st *stream) recvLoop() error { 952 for { 953 req, err := st.gstream.Recv() 954 if err != nil { 955 return err 956 } 957 st.sub.handleStreamingPullRequest(st, req) 958 } 959} 960 961func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) { 962 // Lock the entire server. 963 s.mu.Lock() 964 defer s.mu.Unlock() 965 966 for _, ackID := range req.AckIds { 967 s.ack(ackID) 968 } 969 for i, id := range req.ModifyDeadlineAckIds { 970 s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i])) 971 } 972 if req.StreamAckDeadlineSeconds > 0 { 973 st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds) 974 } 975} 976 977// Must be called with the lock held. 978func (s *subscription) ack(id string) { 979 m := s.msgs[id] 980 if m != nil { 981 (*m.acks)++ 982 delete(s.msgs, id) 983 } 984} 985 986// Must be called with the lock held. 987func (s *subscription) modifyAckDeadline(id string, d time.Duration) { 988 m := s.msgs[id] 989 if m == nil { // already acked: ignore. 990 return 991 } 992 if d == 0 { // nack 993 m.makeAvailable() 994 } else { // extend the deadline by d 995 m.ackDeadline = timeNow().Add(d) 996 } 997} 998 999func secsToDur(secs int32) time.Duration { 1000 return time.Duration(secs) * time.Second 1001} 1002