1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// Package pstest provides a fake Cloud PubSub service for testing. It implements a 16// simplified form of the service, suitable for unit tests. It may behave 17// differently from the actual service in ways in which the service is 18// non-deterministic or unspecified: timing, delivery order, etc. 19// 20// This package is EXPERIMENTAL and is subject to change without notice. 21// 22// See the example for usage. 23package pstest 24 25import ( 26 "context" 27 "fmt" 28 "io" 29 "path" 30 "sort" 31 "strings" 32 "sync" 33 "sync/atomic" 34 "time" 35 36 "cloud.google.com/go/internal/testutil" 37 pb "google.golang.org/genproto/googleapis/pubsub/v1" 38 "google.golang.org/grpc/codes" 39 "google.golang.org/grpc/status" 40 durpb "google.golang.org/protobuf/types/known/durationpb" 41 "google.golang.org/protobuf/types/known/emptypb" 42 "google.golang.org/protobuf/types/known/timestamppb" 43) 44 45// ReactorOptions is a map that Server uses to look up reactors. 46// Key is the function name, value is array of reactor for the function. 47type ReactorOptions map[string][]Reactor 48 49// Reactor is an interface to allow reaction function to a certain call. 50type Reactor interface { 51 // React handles the message types and returns results. If "handled" is false, 52 // then the test server will ignore the results and continue to the next reactor 53 // or the original handler. 54 React(_ interface{}) (handled bool, ret interface{}, err error) 55} 56 57// ServerReactorOption is options passed to the server for reactor creation. 58type ServerReactorOption struct { 59 FuncName string 60 Reactor Reactor 61} 62 63// For testing. Note that even though changes to the now variable are atomic, a call 64// to the stored function can race with a change to that function. This could be a 65// problem if tests are run in parallel, or even if concurrent parts of the same test 66// change the value of the variable. 67var now atomic.Value 68 69func init() { 70 now.Store(time.Now) 71 ResetMinAckDeadline() 72} 73 74func timeNow() time.Time { 75 return now.Load().(func() time.Time)() 76} 77 78// Server is a fake Pub/Sub server. 79type Server struct { 80 srv *testutil.Server 81 Addr string // The address that the server is listening on. 82 GServer GServer // Not intended to be used directly. 83} 84 85// GServer is the underlying service implementor. It is not intended to be used 86// directly. 87type GServer struct { 88 pb.PublisherServer 89 pb.SubscriberServer 90 91 mu sync.Mutex 92 topics map[string]*topic 93 subs map[string]*subscription 94 msgs []*Message // all messages ever published 95 msgsByID map[string]*Message 96 wg sync.WaitGroup 97 nextID int 98 streamTimeout time.Duration 99 timeNowFunc func() time.Time 100 reactorOptions ReactorOptions 101} 102 103// NewServer creates a new fake server running in the current process. 104func NewServer(opts ...ServerReactorOption) *Server { 105 srv, err := testutil.NewServer() 106 if err != nil { 107 panic(fmt.Sprintf("pstest.NewServer: %v", err)) 108 } 109 reactorOptions := ReactorOptions{} 110 for _, opt := range opts { 111 reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor) 112 } 113 s := &Server{ 114 srv: srv, 115 Addr: srv.Addr, 116 GServer: GServer{ 117 topics: map[string]*topic{}, 118 subs: map[string]*subscription{}, 119 msgsByID: map[string]*Message{}, 120 timeNowFunc: timeNow, 121 reactorOptions: reactorOptions, 122 }, 123 } 124 pb.RegisterPublisherServer(srv.Gsrv, &s.GServer) 125 pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer) 126 srv.Start() 127 return s 128} 129 130// SetTimeNowFunc registers f as a function to 131// be used instead of time.Now for this server. 132func (s *Server) SetTimeNowFunc(f func() time.Time) { 133 s.GServer.timeNowFunc = f 134} 135 136// Publish behaves as if the Publish RPC was called with a message with the given 137// data and attrs. It returns the ID of the message. 138// The topic will be created if it doesn't exist. 139// 140// Publish panics if there is an error, which is appropriate for testing. 141func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string { 142 return s.PublishOrdered(topic, data, attrs, "") 143} 144 145// PublishOrdered behaves as if the Publish RPC was called with a message with the given 146// data, attrs and ordering key. It returns the ID of the message. 147// The topic will be created if it doesn't exist. 148// 149// PublishOrdered panics if there is an error, which is appropriate for testing. 150func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]string, orderingKey string) string { 151 const topicPattern = "projects/*/topics/*" 152 ok, err := path.Match(topicPattern, topic) 153 if err != nil { 154 panic(err) 155 } 156 if !ok { 157 panic(fmt.Sprintf("topic name must be of the form %q", topicPattern)) 158 } 159 _, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic}) 160 req := &pb.PublishRequest{ 161 Topic: topic, 162 Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs, OrderingKey: orderingKey}}, 163 } 164 res, err := s.GServer.Publish(context.TODO(), req) 165 if err != nil { 166 panic(fmt.Sprintf("pstest.Server.Publish: %v", err)) 167 } 168 return res.MessageIds[0] 169} 170 171// SetStreamTimeout sets the amount of time a stream will be active before it shuts 172// itself down. This mimics the real service's behavior of closing streams after 30 173// minutes. If SetStreamTimeout is never called or is passed zero, streams never shut 174// down. 175func (s *Server) SetStreamTimeout(d time.Duration) { 176 s.GServer.mu.Lock() 177 defer s.GServer.mu.Unlock() 178 s.GServer.streamTimeout = d 179} 180 181// A Message is a message that was published to the server. 182type Message struct { 183 ID string 184 Data []byte 185 Attributes map[string]string 186 PublishTime time.Time 187 Deliveries int // number of times delivery of the message was attempted 188 Acks int // number of acks received from clients 189 Modacks []Modack // modacks received by server for this message 190 OrderingKey string 191 192 // protected by server mutex 193 deliveries int 194 acks int 195 modacks []Modack 196} 197 198// Modack represents a modack sent to the server. 199type Modack struct { 200 AckID string 201 AckDeadline int32 202 ReceivedAt time.Time 203} 204 205// Messages returns information about all messages ever published. 206func (s *Server) Messages() []*Message { 207 s.GServer.mu.Lock() 208 defer s.GServer.mu.Unlock() 209 210 var msgs []*Message 211 for _, m := range s.GServer.msgs { 212 m.Deliveries = m.deliveries 213 m.Acks = m.acks 214 m.Modacks = append([]Modack(nil), m.modacks...) 215 msgs = append(msgs, m) 216 } 217 return msgs 218} 219 220// Message returns the message with the given ID, or nil if no message 221// with that ID was published. 222func (s *Server) Message(id string) *Message { 223 s.GServer.mu.Lock() 224 defer s.GServer.mu.Unlock() 225 226 m := s.GServer.msgsByID[id] 227 if m != nil { 228 m.Deliveries = m.deliveries 229 m.Acks = m.acks 230 m.Modacks = append([]Modack(nil), m.modacks...) 231 } 232 return m 233} 234 235// Wait blocks until all server activity has completed. 236func (s *Server) Wait() { 237 s.GServer.wg.Wait() 238} 239 240// ClearMessages removes all published messages 241// from internal containers. 242func (s *Server) ClearMessages() { 243 s.GServer.mu.Lock() 244 s.GServer.msgs = nil 245 s.GServer.msgsByID = make(map[string]*Message) 246 s.GServer.mu.Unlock() 247} 248 249// Close shuts down the server and releases all resources. 250func (s *Server) Close() error { 251 s.srv.Close() 252 s.GServer.mu.Lock() 253 defer s.GServer.mu.Unlock() 254 for _, sub := range s.GServer.subs { 255 sub.stop() 256 } 257 return nil 258} 259 260func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) { 261 s.mu.Lock() 262 defer s.mu.Unlock() 263 264 if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil { 265 return ret.(*pb.Topic), err 266 } 267 268 if s.topics[t.Name] != nil { 269 return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name) 270 } 271 top := newTopic(t) 272 s.topics[t.Name] = top 273 return top.proto, nil 274} 275 276func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) { 277 s.mu.Lock() 278 defer s.mu.Unlock() 279 280 if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil { 281 return ret.(*pb.Topic), err 282 } 283 284 if t := s.topics[req.Topic]; t != nil { 285 return t.proto, nil 286 } 287 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 288} 289 290func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) { 291 s.mu.Lock() 292 defer s.mu.Unlock() 293 294 if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil { 295 return ret.(*pb.Topic), err 296 } 297 298 t := s.topics[req.Topic.Name] 299 if t == nil { 300 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name) 301 } 302 for _, path := range req.UpdateMask.Paths { 303 switch path { 304 case "labels": 305 t.proto.Labels = req.Topic.Labels 306 case "message_storage_policy": 307 t.proto.MessageStoragePolicy = req.Topic.MessageStoragePolicy 308 default: 309 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) 310 } 311 } 312 return t.proto, nil 313} 314 315func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) { 316 s.mu.Lock() 317 defer s.mu.Unlock() 318 319 if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil { 320 return ret.(*pb.ListTopicsResponse), err 321 } 322 323 var names []string 324 for n := range s.topics { 325 if strings.HasPrefix(n, req.Project) { 326 names = append(names, n) 327 } 328 } 329 sort.Strings(names) 330 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 331 if err != nil { 332 return nil, err 333 } 334 res := &pb.ListTopicsResponse{NextPageToken: nextToken} 335 for i := from; i < to; i++ { 336 res.Topics = append(res.Topics, s.topics[names[i]].proto) 337 } 338 return res, nil 339} 340 341func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) { 342 s.mu.Lock() 343 defer s.mu.Unlock() 344 345 if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil { 346 return ret.(*pb.ListTopicSubscriptionsResponse), err 347 } 348 349 var names []string 350 for name, sub := range s.subs { 351 if sub.topic.proto.Name == req.Topic { 352 names = append(names, name) 353 } 354 } 355 sort.Strings(names) 356 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 357 if err != nil { 358 return nil, err 359 } 360 return &pb.ListTopicSubscriptionsResponse{ 361 Subscriptions: names[from:to], 362 NextPageToken: nextToken, 363 }, nil 364} 365 366func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) { 367 s.mu.Lock() 368 defer s.mu.Unlock() 369 370 if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil { 371 return ret.(*emptypb.Empty), err 372 } 373 374 t := s.topics[req.Topic] 375 if t == nil { 376 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 377 } 378 t.stop() 379 delete(s.topics, req.Topic) 380 return &emptypb.Empty{}, nil 381} 382 383func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) { 384 s.mu.Lock() 385 defer s.mu.Unlock() 386 387 if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil { 388 return ret.(*pb.Subscription), err 389 } 390 391 if ps.Name == "" { 392 return nil, status.Errorf(codes.InvalidArgument, "missing name") 393 } 394 if s.subs[ps.Name] != nil { 395 return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name) 396 } 397 if ps.Topic == "" { 398 return nil, status.Errorf(codes.InvalidArgument, "missing topic") 399 } 400 top := s.topics[ps.Topic] 401 if top == nil { 402 return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic) 403 } 404 if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil { 405 return nil, err 406 } 407 if ps.MessageRetentionDuration == nil { 408 ps.MessageRetentionDuration = defaultMessageRetentionDuration 409 } 410 if err := checkMRD(ps.MessageRetentionDuration); err != nil { 411 return nil, err 412 } 413 if ps.PushConfig == nil { 414 ps.PushConfig = &pb.PushConfig{} 415 } 416 417 sub := newSubscription(top, &s.mu, s.timeNowFunc, ps) 418 top.subs[ps.Name] = sub 419 s.subs[ps.Name] = sub 420 sub.start(&s.wg) 421 return ps, nil 422} 423 424// Can be set for testing. 425var minAckDeadlineSecs int32 426 427// SetMinAckDeadline changes the minack deadline to n. Must be 428// greater than or equal to 1 second. Remember to reset this value 429// to the default after your test changes it. Example usage: 430// pstest.SetMinAckDeadlineSecs(1) 431// defer pstest.ResetMinAckDeadlineSecs() 432func SetMinAckDeadline(n time.Duration) { 433 if n < time.Second { 434 panic("SetMinAckDeadline expects a value greater than 1 second") 435 } 436 437 minAckDeadlineSecs = int32(n / time.Second) 438} 439 440// ResetMinAckDeadline resets the minack deadline to the default. 441func ResetMinAckDeadline() { 442 minAckDeadlineSecs = 10 443} 444 445func checkAckDeadline(ads int32) error { 446 if ads < minAckDeadlineSecs || ads > 600 { 447 // PubSub service returns Unknown. 448 return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads) 449 } 450 return nil 451} 452 453const ( 454 minMessageRetentionDuration = 10 * time.Minute 455 maxMessageRetentionDuration = 168 * time.Hour 456) 457 458var defaultMessageRetentionDuration = durpb.New(maxMessageRetentionDuration) 459 460func checkMRD(pmrd *durpb.Duration) error { 461 mrd := pmrd.AsDuration() 462 if mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration { 463 return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd) 464 } 465 return nil 466} 467 468func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) { 469 s.mu.Lock() 470 defer s.mu.Unlock() 471 472 if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil { 473 return ret.(*pb.Subscription), err 474 } 475 476 sub, err := s.findSubscription(req.Subscription) 477 if err != nil { 478 return nil, err 479 } 480 return sub.proto, nil 481} 482 483func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) { 484 if req.Subscription == nil { 485 return nil, status.Errorf(codes.InvalidArgument, "missing subscription") 486 } 487 s.mu.Lock() 488 defer s.mu.Unlock() 489 490 if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil { 491 return ret.(*pb.Subscription), err 492 } 493 494 sub, err := s.findSubscription(req.Subscription.Name) 495 if err != nil { 496 return nil, err 497 } 498 for _, path := range req.UpdateMask.Paths { 499 switch path { 500 case "push_config": 501 sub.proto.PushConfig = req.Subscription.PushConfig 502 503 case "ack_deadline_seconds": 504 a := req.Subscription.AckDeadlineSeconds 505 if err := checkAckDeadline(a); err != nil { 506 return nil, err 507 } 508 sub.proto.AckDeadlineSeconds = a 509 510 case "retain_acked_messages": 511 sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages 512 513 case "message_retention_duration": 514 if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil { 515 return nil, err 516 } 517 sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration 518 519 case "labels": 520 sub.proto.Labels = req.Subscription.Labels 521 522 case "expiration_policy": 523 sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy 524 525 case "dead_letter_policy": 526 sub.proto.DeadLetterPolicy = req.Subscription.DeadLetterPolicy 527 528 case "retry_policy": 529 sub.proto.RetryPolicy = req.Subscription.RetryPolicy 530 531 case "filter": 532 sub.proto.Filter = req.Subscription.Filter 533 534 default: 535 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) 536 } 537 } 538 return sub.proto, nil 539} 540 541func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) { 542 s.mu.Lock() 543 defer s.mu.Unlock() 544 545 if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil { 546 return ret.(*pb.ListSubscriptionsResponse), err 547 } 548 549 var names []string 550 for name := range s.subs { 551 if strings.HasPrefix(name, req.Project) { 552 names = append(names, name) 553 } 554 } 555 sort.Strings(names) 556 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 557 if err != nil { 558 return nil, err 559 } 560 res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken} 561 for i := from; i < to; i++ { 562 res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto) 563 } 564 return res, nil 565} 566 567func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) { 568 s.mu.Lock() 569 defer s.mu.Unlock() 570 571 if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil { 572 return ret.(*emptypb.Empty), err 573 } 574 575 sub, err := s.findSubscription(req.Subscription) 576 if err != nil { 577 return nil, err 578 } 579 sub.stop() 580 delete(s.subs, req.Subscription) 581 sub.topic.deleteSub(sub) 582 return &emptypb.Empty{}, nil 583} 584 585func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) { 586 s.mu.Lock() 587 defer s.mu.Unlock() 588 589 if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil { 590 return ret.(*pb.DetachSubscriptionResponse), err 591 } 592 593 sub, err := s.findSubscription(req.Subscription) 594 if err != nil { 595 return nil, err 596 } 597 sub.topic.deleteSub(sub) 598 return &pb.DetachSubscriptionResponse{}, nil 599} 600 601func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) { 602 s.mu.Lock() 603 defer s.mu.Unlock() 604 605 if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil { 606 return ret.(*pb.PublishResponse), err 607 } 608 609 if req.Topic == "" { 610 return nil, status.Errorf(codes.InvalidArgument, "missing topic") 611 } 612 top := s.topics[req.Topic] 613 if top == nil { 614 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 615 } 616 var ids []string 617 for _, pm := range req.Messages { 618 id := fmt.Sprintf("m%d", s.nextID) 619 s.nextID++ 620 pm.MessageId = id 621 pubTime := s.timeNowFunc() 622 tsPubTime := timestamppb.New(pubTime) 623 pm.PublishTime = tsPubTime 624 m := &Message{ 625 ID: id, 626 Data: pm.Data, 627 Attributes: pm.Attributes, 628 PublishTime: pubTime, 629 OrderingKey: pm.OrderingKey, 630 } 631 top.publish(pm, m) 632 ids = append(ids, id) 633 s.msgs = append(s.msgs, m) 634 s.msgsByID[id] = m 635 } 636 return &pb.PublishResponse{MessageIds: ids}, nil 637} 638 639type topic struct { 640 proto *pb.Topic 641 subs map[string]*subscription 642} 643 644func newTopic(pt *pb.Topic) *topic { 645 return &topic{ 646 proto: pt, 647 subs: map[string]*subscription{}, 648 } 649} 650 651func (t *topic) stop() { 652 for _, sub := range t.subs { 653 sub.proto.Topic = "_deleted-topic_" 654 } 655} 656 657func (t *topic) deleteSub(sub *subscription) { 658 delete(t.subs, sub.proto.Name) 659} 660 661func (t *topic) publish(pm *pb.PubsubMessage, m *Message) { 662 for _, s := range t.subs { 663 s.msgs[pm.MessageId] = &message{ 664 publishTime: m.PublishTime, 665 proto: &pb.ReceivedMessage{ 666 AckId: pm.MessageId, 667 Message: pm, 668 }, 669 deliveries: &m.deliveries, 670 acks: &m.acks, 671 streamIndex: -1, 672 } 673 } 674} 675 676type subscription struct { 677 topic *topic 678 mu *sync.Mutex // the server mutex, here for convenience 679 proto *pb.Subscription 680 ackTimeout time.Duration 681 msgs map[string]*message // unacked messages by message ID 682 streams []*stream 683 done chan struct{} 684 timeNowFunc func() time.Time 685} 686 687func newSubscription(t *topic, mu *sync.Mutex, timeNowFunc func() time.Time, ps *pb.Subscription) *subscription { 688 at := time.Duration(ps.AckDeadlineSeconds) * time.Second 689 if at == 0 { 690 at = 10 * time.Second 691 } 692 return &subscription{ 693 topic: t, 694 mu: mu, 695 proto: ps, 696 ackTimeout: at, 697 msgs: map[string]*message{}, 698 done: make(chan struct{}), 699 timeNowFunc: timeNowFunc, 700 } 701} 702 703func (s *subscription) start(wg *sync.WaitGroup) { 704 wg.Add(1) 705 go func() { 706 defer wg.Done() 707 for { 708 select { 709 case <-s.done: 710 return 711 case <-time.After(10 * time.Millisecond): 712 s.deliver() 713 } 714 } 715 }() 716} 717 718func (s *subscription) stop() { 719 close(s.done) 720} 721 722func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) { 723 s.mu.Lock() 724 defer s.mu.Unlock() 725 726 if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil { 727 return ret.(*emptypb.Empty), err 728 } 729 730 sub, err := s.findSubscription(req.Subscription) 731 if err != nil { 732 return nil, err 733 } 734 for _, id := range req.AckIds { 735 sub.ack(id) 736 } 737 return &emptypb.Empty{}, nil 738} 739 740func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) { 741 s.mu.Lock() 742 defer s.mu.Unlock() 743 744 if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil { 745 return ret.(*emptypb.Empty), err 746 } 747 748 sub, err := s.findSubscription(req.Subscription) 749 if err != nil { 750 return nil, err 751 } 752 now := time.Now() 753 for _, id := range req.AckIds { 754 s.msgsByID[id].modacks = append(s.msgsByID[id].modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now}) 755 } 756 dur := secsToDur(req.AckDeadlineSeconds) 757 for _, id := range req.AckIds { 758 sub.modifyAckDeadline(id, dur) 759 } 760 return &emptypb.Empty{}, nil 761} 762 763func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) { 764 s.mu.Lock() 765 766 if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil { 767 s.mu.Unlock() 768 return ret.(*pb.PullResponse), err 769 } 770 771 sub, err := s.findSubscription(req.Subscription) 772 if err != nil { 773 s.mu.Unlock() 774 return nil, err 775 } 776 max := int(req.MaxMessages) 777 if max < 0 { 778 s.mu.Unlock() 779 return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative") 780 } 781 if max == 0 { // MaxMessages not specified; use a default. 782 max = 1000 783 } 784 msgs := sub.pull(max) 785 s.mu.Unlock() 786 // Implement the spec from the pubsub proto: 787 // "If ReturnImmediately set to true, the system will respond immediately even if 788 // it there are no messages available to return in the `Pull` response. 789 // Otherwise, the system may wait (for a bounded amount of time) until at 790 // least one message is available, rather than returning no messages." 791 if len(msgs) == 0 && !req.ReturnImmediately { 792 // Wait for a short amount of time for a message. 793 // TODO: signal when a message arrives, so we don't wait the whole time. 794 select { 795 case <-ctx.Done(): 796 return nil, ctx.Err() 797 case <-time.After(500 * time.Millisecond): 798 s.mu.Lock() 799 msgs = sub.pull(max) 800 s.mu.Unlock() 801 } 802 } 803 return &pb.PullResponse{ReceivedMessages: msgs}, nil 804} 805 806func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error { 807 // Receive initial message configuring the pull. 808 req, err := sps.Recv() 809 if err != nil { 810 return err 811 } 812 s.mu.Lock() 813 sub, err := s.findSubscription(req.Subscription) 814 s.mu.Unlock() 815 if err != nil { 816 return err 817 } 818 // Create a new stream to handle the pull. 819 st := sub.newStream(sps, s.streamTimeout) 820 err = st.pull(&s.wg) 821 sub.deleteStream(st) 822 return err 823} 824 825func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) { 826 // Only handle time-based seeking for now. 827 // This fake doesn't deal with snapshots. 828 var target time.Time 829 switch v := req.Target.(type) { 830 case nil: 831 return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type") 832 case *pb.SeekRequest_Time: 833 target = v.Time.AsTime() 834 default: 835 return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v) 836 } 837 838 // The entire server must be locked while doing the work below, 839 // because the messages don't have any other synchronization. 840 s.mu.Lock() 841 defer s.mu.Unlock() 842 843 if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil { 844 return ret.(*pb.SeekResponse), err 845 } 846 847 sub, err := s.findSubscription(req.Subscription) 848 if err != nil { 849 return nil, err 850 } 851 // Drop all messages from sub that were published before the target time. 852 for id, m := range sub.msgs { 853 if m.publishTime.Before(target) { 854 delete(sub.msgs, id) 855 (*m.acks)++ 856 } 857 } 858 // Un-ack any already-acked messages after this time; 859 // redelivering them to the subscription is the closest analogue here. 860 for _, m := range s.msgs { 861 if m.PublishTime.Before(target) { 862 continue 863 } 864 sub.msgs[m.ID] = &message{ 865 publishTime: m.PublishTime, 866 proto: &pb.ReceivedMessage{ 867 AckId: m.ID, 868 // This was not preserved! 869 //Message: pm, 870 }, 871 deliveries: &m.deliveries, 872 acks: &m.acks, 873 streamIndex: -1, 874 } 875 } 876 return &pb.SeekResponse{}, nil 877} 878 879// Gets a subscription that must exist. 880// Must be called with the lock held. 881func (s *GServer) findSubscription(name string) (*subscription, error) { 882 if name == "" { 883 return nil, status.Errorf(codes.InvalidArgument, "missing subscription") 884 } 885 sub := s.subs[name] 886 if sub == nil { 887 return nil, status.Errorf(codes.NotFound, "subscription %s", name) 888 } 889 return sub, nil 890} 891 892// Must be called with the lock held. 893func (s *subscription) pull(max int) []*pb.ReceivedMessage { 894 now := s.timeNowFunc() 895 s.maintainMessages(now) 896 var msgs []*pb.ReceivedMessage 897 for _, m := range s.msgs { 898 if m.outstanding() { 899 continue 900 } 901 (*m.deliveries)++ 902 m.ackDeadline = now.Add(s.ackTimeout) 903 msgs = append(msgs, m.proto) 904 if len(msgs) >= max { 905 break 906 } 907 } 908 return msgs 909} 910 911func (s *subscription) deliver() { 912 s.mu.Lock() 913 defer s.mu.Unlock() 914 915 now := s.timeNowFunc() 916 s.maintainMessages(now) 917 // Try to deliver each remaining message. 918 curIndex := 0 919 for _, m := range s.msgs { 920 if m.outstanding() { 921 continue 922 } 923 // If the message was never delivered before, start with the stream at 924 // curIndex. If it was delivered before, start with the stream after the one 925 // that owned it. 926 if m.streamIndex < 0 { 927 delIndex, ok := s.tryDeliverMessage(m, curIndex, now) 928 if !ok { 929 break 930 } 931 curIndex = delIndex + 1 932 m.streamIndex = curIndex 933 } else { 934 delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now) 935 if !ok { 936 break 937 } 938 m.streamIndex = delIndex 939 } 940 } 941} 942 943// tryDeliverMessage attempts to deliver m to the stream at index i. If it can't, it 944// tries streams i+1, i+2, ..., wrapping around. Once it's tried all streams, it 945// exits. 946// 947// It returns the index of the stream it delivered the message to, or 0, false if 948// it didn't deliver the message. 949// 950// Must be called with the lock held. 951func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) { 952 for i := 0; i < len(s.streams); i++ { 953 idx := (i + start) % len(s.streams) 954 955 st := s.streams[idx] 956 select { 957 case <-st.done: 958 s.streams = deleteStreamAt(s.streams, idx) 959 i-- 960 961 case st.msgc <- m.proto: 962 (*m.deliveries)++ 963 m.ackDeadline = now.Add(st.ackTimeout) 964 return idx, true 965 966 default: 967 } 968 } 969 return 0, false 970} 971 972var retentionDuration = 10 * time.Minute 973 974// Must be called with the lock held. 975func (s *subscription) maintainMessages(now time.Time) { 976 for id, m := range s.msgs { 977 // Mark a message as re-deliverable if its ack deadline has expired. 978 if m.outstanding() && now.After(m.ackDeadline) { 979 m.makeAvailable() 980 } 981 pubTime := m.proto.Message.PublishTime.AsTime() 982 // Remove messages that have been undelivered for a long time. 983 if !m.outstanding() && now.Sub(pubTime) > retentionDuration { 984 delete(s.msgs, id) 985 } 986 } 987} 988 989func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream { 990 st := &stream{ 991 sub: s, 992 done: make(chan struct{}), 993 msgc: make(chan *pb.ReceivedMessage), 994 gstream: gs, 995 ackTimeout: s.ackTimeout, 996 timeout: timeout, 997 } 998 s.mu.Lock() 999 s.streams = append(s.streams, st) 1000 s.mu.Unlock() 1001 return st 1002} 1003 1004func (s *subscription) deleteStream(st *stream) { 1005 s.mu.Lock() 1006 defer s.mu.Unlock() 1007 var i int 1008 for i = 0; i < len(s.streams); i++ { 1009 if s.streams[i] == st { 1010 break 1011 } 1012 } 1013 if i < len(s.streams) { 1014 s.streams = deleteStreamAt(s.streams, i) 1015 } 1016} 1017func deleteStreamAt(s []*stream, i int) []*stream { 1018 // Preserve order for round-robin delivery. 1019 return append(s[:i], s[i+1:]...) 1020} 1021 1022type message struct { 1023 proto *pb.ReceivedMessage 1024 publishTime time.Time 1025 ackDeadline time.Time 1026 deliveries *int 1027 acks *int 1028 streamIndex int // index of stream that currently owns msg, for round-robin delivery 1029} 1030 1031// A message is outstanding if it is owned by some stream. 1032func (m *message) outstanding() bool { 1033 return !m.ackDeadline.IsZero() 1034} 1035 1036func (m *message) makeAvailable() { 1037 m.ackDeadline = time.Time{} 1038} 1039 1040type stream struct { 1041 sub *subscription 1042 done chan struct{} // closed when the stream is finished 1043 msgc chan *pb.ReceivedMessage 1044 gstream pb.Subscriber_StreamingPullServer 1045 ackTimeout time.Duration 1046 timeout time.Duration 1047} 1048 1049// pull manages the StreamingPull interaction for the life of the stream. 1050func (st *stream) pull(wg *sync.WaitGroup) error { 1051 errc := make(chan error, 2) 1052 wg.Add(2) 1053 go func() { 1054 defer wg.Done() 1055 errc <- st.sendLoop() 1056 }() 1057 go func() { 1058 defer wg.Done() 1059 errc <- st.recvLoop() 1060 }() 1061 var tchan <-chan time.Time 1062 if st.timeout > 0 { 1063 tchan = time.After(st.timeout) 1064 } 1065 // Wait until one of the goroutines returns an error, or we time out. 1066 var err error 1067 select { 1068 case err = <-errc: 1069 if err == io.EOF { 1070 err = nil 1071 } 1072 case <-tchan: 1073 } 1074 close(st.done) // stop the other goroutine 1075 return err 1076} 1077 1078func (st *stream) sendLoop() error { 1079 for { 1080 select { 1081 case <-st.done: 1082 return nil 1083 case rm := <-st.msgc: 1084 res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}} 1085 if err := st.gstream.Send(res); err != nil { 1086 return err 1087 } 1088 } 1089 } 1090} 1091 1092func (st *stream) recvLoop() error { 1093 for { 1094 req, err := st.gstream.Recv() 1095 if err != nil { 1096 return err 1097 } 1098 st.sub.handleStreamingPullRequest(st, req) 1099 } 1100} 1101 1102func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) { 1103 // Lock the entire server. 1104 s.mu.Lock() 1105 defer s.mu.Unlock() 1106 1107 for _, ackID := range req.AckIds { 1108 s.ack(ackID) 1109 } 1110 for i, id := range req.ModifyDeadlineAckIds { 1111 s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i])) 1112 } 1113 if req.StreamAckDeadlineSeconds > 0 { 1114 st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds) 1115 } 1116} 1117 1118// Must be called with the lock held. 1119func (s *subscription) ack(id string) { 1120 m := s.msgs[id] 1121 if m != nil { 1122 (*m.acks)++ 1123 delete(s.msgs, id) 1124 } 1125} 1126 1127// Must be called with the lock held. 1128func (s *subscription) modifyAckDeadline(id string, d time.Duration) { 1129 m := s.msgs[id] 1130 if m == nil { // already acked: ignore. 1131 return 1132 } 1133 if d == 0 { // nack 1134 m.makeAvailable() 1135 } else { // extend the deadline by d 1136 m.ackDeadline = s.timeNowFunc().Add(d) 1137 } 1138} 1139 1140func secsToDur(secs int32) time.Duration { 1141 return time.Duration(secs) * time.Second 1142} 1143 1144// runReactor looks up the reactors for a function, then launches them until handled=true 1145// or err is returned. If the reactor returns nil, the function returns defaultObj instead. 1146func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) { 1147 if val, ok := s.reactorOptions[funcName]; ok { 1148 for _, reactor := range val { 1149 handled, ret, err := reactor.React(req) 1150 // If handled=true, that means the reactor has successfully reacted to the request, 1151 // so use the output directly. If err occurs, that means the request is invalidated 1152 // by the reactor somehow. 1153 if handled || err != nil { 1154 if ret == nil { 1155 ret = defaultObj 1156 } 1157 return true, ret, err 1158 } 1159 } 1160 } 1161 return false, nil, nil 1162} 1163 1164// errorInjectionReactor is a reactor to inject an error message with status code. 1165type errorInjectionReactor struct { 1166 code codes.Code 1167 msg string 1168} 1169 1170// React simply returns an error with defined error message and status code. 1171func (e *errorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) { 1172 return true, nil, status.Errorf(e.code, e.msg) 1173} 1174 1175// WithErrorInjection creates a ServerReactorOption that injects error with defined status code and 1176// message for a certain function. 1177func WithErrorInjection(funcName string, code codes.Code, msg string) ServerReactorOption { 1178 return ServerReactorOption{ 1179 FuncName: funcName, 1180 Reactor: &errorInjectionReactor{code: code, msg: msg}, 1181 } 1182} 1183