1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// Package pstest provides a fake Cloud PubSub service for testing. It implements a 16// simplified form of the service, suitable for unit tests. It may behave 17// differently from the actual service in ways in which the service is 18// non-deterministic or unspecified: timing, delivery order, etc. 19// 20// This package is EXPERIMENTAL and is subject to change without notice. 21// 22// See the example for usage. 23package pstest 24 25import ( 26 "context" 27 "fmt" 28 "io" 29 "path" 30 "sort" 31 "strings" 32 "sync" 33 "sync/atomic" 34 "time" 35 36 "cloud.google.com/go/internal/testutil" 37 pb "google.golang.org/genproto/googleapis/pubsub/v1" 38 "google.golang.org/grpc/codes" 39 "google.golang.org/grpc/status" 40 durpb "google.golang.org/protobuf/types/known/durationpb" 41 "google.golang.org/protobuf/types/known/emptypb" 42 "google.golang.org/protobuf/types/known/timestamppb" 43) 44 45// ReactorOptions is a map that Server uses to look up reactors. 46// Key is the function name, value is array of reactor for the function. 47type ReactorOptions map[string][]Reactor 48 49// Reactor is an interface to allow reaction function to a certain call. 50type Reactor interface { 51 // React handles the message types and returns results. If "handled" is false, 52 // then the test server will ignore the results and continue to the next reactor 53 // or the original handler. 54 React(_ interface{}) (handled bool, ret interface{}, err error) 55} 56 57// ServerReactorOption is options passed to the server for reactor creation. 58type ServerReactorOption struct { 59 FuncName string 60 Reactor Reactor 61} 62 63type publishResponse struct { 64 resp *pb.PublishResponse 65 err error 66} 67 68// For testing. Note that even though changes to the now variable are atomic, a call 69// to the stored function can race with a change to that function. This could be a 70// problem if tests are run in parallel, or even if concurrent parts of the same test 71// change the value of the variable. 72var now atomic.Value 73 74func init() { 75 now.Store(time.Now) 76 ResetMinAckDeadline() 77} 78 79func timeNow() time.Time { 80 return now.Load().(func() time.Time)() 81} 82 83// Server is a fake Pub/Sub server. 84type Server struct { 85 srv *testutil.Server 86 Addr string // The address that the server is listening on. 87 GServer GServer // Not intended to be used directly. 88} 89 90// GServer is the underlying service implementor. It is not intended to be used 91// directly. 92type GServer struct { 93 pb.PublisherServer 94 pb.SubscriberServer 95 96 mu sync.Mutex 97 topics map[string]*topic 98 subs map[string]*subscription 99 msgs []*Message // all messages ever published 100 msgsByID map[string]*Message 101 wg sync.WaitGroup 102 nextID int 103 streamTimeout time.Duration 104 timeNowFunc func() time.Time 105 reactorOptions ReactorOptions 106 schemas map[string]*pb.Schema 107 108 // PublishResponses is a channel of responses to use for Publish. 109 publishResponses chan *publishResponse 110 // autoPublishResponse enables the server to automatically generate 111 // PublishResponse when publish is called. Otherwise, responses 112 // are generated from the publishResponses channel. 113 autoPublishResponse bool 114} 115 116// NewServer creates a new fake server running in the current process. 117func NewServer(opts ...ServerReactorOption) *Server { 118 return NewServerWithPort(0, opts...) 119} 120 121// NewServerWithPort creates a new fake server running in the current process at the specified port. 122func NewServerWithPort(port int, opts ...ServerReactorOption) *Server { 123 srv, err := testutil.NewServerWithPort(port) 124 if err != nil { 125 panic(fmt.Sprintf("pstest.NewServerWithPort: %v", err)) 126 } 127 reactorOptions := ReactorOptions{} 128 for _, opt := range opts { 129 reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor) 130 } 131 s := &Server{ 132 srv: srv, 133 Addr: srv.Addr, 134 GServer: GServer{ 135 topics: map[string]*topic{}, 136 subs: map[string]*subscription{}, 137 msgsByID: map[string]*Message{}, 138 timeNowFunc: timeNow, 139 reactorOptions: reactorOptions, 140 publishResponses: make(chan *publishResponse, 100), 141 autoPublishResponse: true, 142 schemas: map[string]*pb.Schema{}, 143 }, 144 } 145 pb.RegisterPublisherServer(srv.Gsrv, &s.GServer) 146 pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer) 147 pb.RegisterSchemaServiceServer(srv.Gsrv, &s.GServer) 148 srv.Start() 149 return s 150} 151 152// SetTimeNowFunc registers f as a function to 153// be used instead of time.Now for this server. 154func (s *Server) SetTimeNowFunc(f func() time.Time) { 155 s.GServer.timeNowFunc = f 156} 157 158// Publish behaves as if the Publish RPC was called with a message with the given 159// data and attrs. It returns the ID of the message. 160// The topic will be created if it doesn't exist. 161// 162// Publish panics if there is an error, which is appropriate for testing. 163func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string { 164 return s.PublishOrdered(topic, data, attrs, "") 165} 166 167// PublishOrdered behaves as if the Publish RPC was called with a message with the given 168// data, attrs and ordering key. It returns the ID of the message. 169// The topic will be created if it doesn't exist. 170// 171// PublishOrdered panics if there is an error, which is appropriate for testing. 172func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]string, orderingKey string) string { 173 const topicPattern = "projects/*/topics/*" 174 ok, err := path.Match(topicPattern, topic) 175 if err != nil { 176 panic(err) 177 } 178 if !ok { 179 panic(fmt.Sprintf("topic name must be of the form %q", topicPattern)) 180 } 181 _, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic}) 182 req := &pb.PublishRequest{ 183 Topic: topic, 184 Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs, OrderingKey: orderingKey}}, 185 } 186 res, err := s.GServer.Publish(context.TODO(), req) 187 if err != nil { 188 panic(fmt.Sprintf("pstest.Server.Publish: %v", err)) 189 } 190 return res.MessageIds[0] 191} 192 193// AddPublishResponse adds a new publish response to the channel used for 194// responding to publish requests. 195func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) { 196 pr := &publishResponse{} 197 if err != nil { 198 pr.err = err 199 } else { 200 pr.resp = pbr 201 } 202 s.GServer.publishResponses <- pr 203} 204 205// SetAutoPublishResponse controls whether to automatically respond 206// to messages published or to use user-added responses from the 207// publishResponses channel. 208func (s *Server) SetAutoPublishResponse(autoPublishResponse bool) { 209 s.GServer.mu.Lock() 210 defer s.GServer.mu.Unlock() 211 s.GServer.autoPublishResponse = autoPublishResponse 212} 213 214// ResetPublishResponses resets the buffered publishResponses channel 215// with a new buffered channel with the given size. 216func (s *Server) ResetPublishResponses(size int) { 217 s.GServer.mu.Lock() 218 defer s.GServer.mu.Unlock() 219 s.GServer.publishResponses = make(chan *publishResponse, size) 220} 221 222// SetStreamTimeout sets the amount of time a stream will be active before it shuts 223// itself down. This mimics the real service's behavior of closing streams after 30 224// minutes. If SetStreamTimeout is never called or is passed zero, streams never shut 225// down. 226func (s *Server) SetStreamTimeout(d time.Duration) { 227 s.GServer.mu.Lock() 228 defer s.GServer.mu.Unlock() 229 s.GServer.streamTimeout = d 230} 231 232// A Message is a message that was published to the server. 233type Message struct { 234 ID string 235 Data []byte 236 Attributes map[string]string 237 PublishTime time.Time 238 Deliveries int // number of times delivery of the message was attempted 239 Acks int // number of acks received from clients 240 Modacks []Modack // modacks received by server for this message 241 OrderingKey string 242 243 // protected by server mutex 244 deliveries int 245 acks int 246 modacks []Modack 247} 248 249// Modack represents a modack sent to the server. 250type Modack struct { 251 AckID string 252 AckDeadline int32 253 ReceivedAt time.Time 254} 255 256// Messages returns information about all messages ever published. 257func (s *Server) Messages() []*Message { 258 s.GServer.mu.Lock() 259 defer s.GServer.mu.Unlock() 260 261 var msgs []*Message 262 for _, m := range s.GServer.msgs { 263 m.Deliveries = m.deliveries 264 m.Acks = m.acks 265 m.Modacks = append([]Modack(nil), m.modacks...) 266 msgs = append(msgs, m) 267 } 268 return msgs 269} 270 271// Message returns the message with the given ID, or nil if no message 272// with that ID was published. 273func (s *Server) Message(id string) *Message { 274 s.GServer.mu.Lock() 275 defer s.GServer.mu.Unlock() 276 277 m := s.GServer.msgsByID[id] 278 if m != nil { 279 m.Deliveries = m.deliveries 280 m.Acks = m.acks 281 m.Modacks = append([]Modack(nil), m.modacks...) 282 } 283 return m 284} 285 286// Wait blocks until all server activity has completed. 287func (s *Server) Wait() { 288 s.GServer.wg.Wait() 289} 290 291// ClearMessages removes all published messages 292// from internal containers. 293func (s *Server) ClearMessages() { 294 s.GServer.mu.Lock() 295 s.GServer.msgs = nil 296 s.GServer.msgsByID = make(map[string]*Message) 297 s.GServer.mu.Unlock() 298} 299 300// Close shuts down the server and releases all resources. 301func (s *Server) Close() error { 302 s.srv.Close() 303 s.GServer.mu.Lock() 304 defer s.GServer.mu.Unlock() 305 for _, sub := range s.GServer.subs { 306 sub.stop() 307 } 308 return nil 309} 310 311func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) { 312 s.mu.Lock() 313 defer s.mu.Unlock() 314 315 if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil { 316 return ret.(*pb.Topic), err 317 } 318 319 if s.topics[t.Name] != nil { 320 return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name) 321 } 322 top := newTopic(t) 323 s.topics[t.Name] = top 324 return top.proto, nil 325} 326 327func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) { 328 s.mu.Lock() 329 defer s.mu.Unlock() 330 331 if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil { 332 return ret.(*pb.Topic), err 333 } 334 335 if t := s.topics[req.Topic]; t != nil { 336 return t.proto, nil 337 } 338 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 339} 340 341func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) { 342 s.mu.Lock() 343 defer s.mu.Unlock() 344 345 if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil { 346 return ret.(*pb.Topic), err 347 } 348 349 t := s.topics[req.Topic.Name] 350 if t == nil { 351 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name) 352 } 353 for _, path := range req.UpdateMask.Paths { 354 switch path { 355 case "labels": 356 t.proto.Labels = req.Topic.Labels 357 case "message_storage_policy": 358 t.proto.MessageStoragePolicy = req.Topic.MessageStoragePolicy 359 default: 360 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) 361 } 362 } 363 return t.proto, nil 364} 365 366func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) { 367 s.mu.Lock() 368 defer s.mu.Unlock() 369 370 if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil { 371 return ret.(*pb.ListTopicsResponse), err 372 } 373 374 var names []string 375 for n := range s.topics { 376 if strings.HasPrefix(n, req.Project) { 377 names = append(names, n) 378 } 379 } 380 sort.Strings(names) 381 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 382 if err != nil { 383 return nil, err 384 } 385 res := &pb.ListTopicsResponse{NextPageToken: nextToken} 386 for i := from; i < to; i++ { 387 res.Topics = append(res.Topics, s.topics[names[i]].proto) 388 } 389 return res, nil 390} 391 392func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) { 393 s.mu.Lock() 394 defer s.mu.Unlock() 395 396 if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil { 397 return ret.(*pb.ListTopicSubscriptionsResponse), err 398 } 399 400 var names []string 401 for name, sub := range s.subs { 402 if sub.topic.proto.Name == req.Topic { 403 names = append(names, name) 404 } 405 } 406 sort.Strings(names) 407 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 408 if err != nil { 409 return nil, err 410 } 411 return &pb.ListTopicSubscriptionsResponse{ 412 Subscriptions: names[from:to], 413 NextPageToken: nextToken, 414 }, nil 415} 416 417func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) { 418 s.mu.Lock() 419 defer s.mu.Unlock() 420 421 if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil { 422 return ret.(*emptypb.Empty), err 423 } 424 425 t := s.topics[req.Topic] 426 if t == nil { 427 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 428 } 429 t.stop() 430 delete(s.topics, req.Topic) 431 return &emptypb.Empty{}, nil 432} 433 434func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) { 435 s.mu.Lock() 436 defer s.mu.Unlock() 437 438 if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil { 439 return ret.(*pb.Subscription), err 440 } 441 442 if ps.Name == "" { 443 return nil, status.Errorf(codes.InvalidArgument, "missing name") 444 } 445 if s.subs[ps.Name] != nil { 446 return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name) 447 } 448 if ps.Topic == "" { 449 return nil, status.Errorf(codes.InvalidArgument, "missing topic") 450 } 451 top := s.topics[ps.Topic] 452 if top == nil { 453 return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic) 454 } 455 if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil { 456 return nil, err 457 } 458 if ps.MessageRetentionDuration == nil { 459 ps.MessageRetentionDuration = defaultMessageRetentionDuration 460 } 461 if err := checkMRD(ps.MessageRetentionDuration); err != nil { 462 return nil, err 463 } 464 if ps.PushConfig == nil { 465 ps.PushConfig = &pb.PushConfig{} 466 } 467 468 sub := newSubscription(top, &s.mu, s.timeNowFunc, ps) 469 top.subs[ps.Name] = sub 470 s.subs[ps.Name] = sub 471 sub.start(&s.wg) 472 return ps, nil 473} 474 475// Can be set for testing. 476var minAckDeadlineSecs int32 477 478// SetMinAckDeadline changes the minack deadline to n. Must be 479// greater than or equal to 1 second. Remember to reset this value 480// to the default after your test changes it. Example usage: 481// pstest.SetMinAckDeadlineSecs(1) 482// defer pstest.ResetMinAckDeadlineSecs() 483func SetMinAckDeadline(n time.Duration) { 484 if n < time.Second { 485 panic("SetMinAckDeadline expects a value greater than 1 second") 486 } 487 488 minAckDeadlineSecs = int32(n / time.Second) 489} 490 491// ResetMinAckDeadline resets the minack deadline to the default. 492func ResetMinAckDeadline() { 493 minAckDeadlineSecs = 10 494} 495 496func checkAckDeadline(ads int32) error { 497 if ads < minAckDeadlineSecs || ads > 600 { 498 // PubSub service returns Unknown. 499 return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads) 500 } 501 return nil 502} 503 504const ( 505 minMessageRetentionDuration = 10 * time.Minute 506 maxMessageRetentionDuration = 168 * time.Hour 507) 508 509var defaultMessageRetentionDuration = durpb.New(maxMessageRetentionDuration) 510 511func checkMRD(pmrd *durpb.Duration) error { 512 mrd := pmrd.AsDuration() 513 if mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration { 514 return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd) 515 } 516 return nil 517} 518 519func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) { 520 s.mu.Lock() 521 defer s.mu.Unlock() 522 523 if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil { 524 return ret.(*pb.Subscription), err 525 } 526 527 sub, err := s.findSubscription(req.Subscription) 528 if err != nil { 529 return nil, err 530 } 531 return sub.proto, nil 532} 533 534func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) { 535 if req.Subscription == nil { 536 return nil, status.Errorf(codes.InvalidArgument, "missing subscription") 537 } 538 s.mu.Lock() 539 defer s.mu.Unlock() 540 541 if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil { 542 return ret.(*pb.Subscription), err 543 } 544 545 sub, err := s.findSubscription(req.Subscription.Name) 546 if err != nil { 547 return nil, err 548 } 549 for _, path := range req.UpdateMask.Paths { 550 switch path { 551 case "push_config": 552 sub.proto.PushConfig = req.Subscription.PushConfig 553 554 case "ack_deadline_seconds": 555 a := req.Subscription.AckDeadlineSeconds 556 if err := checkAckDeadline(a); err != nil { 557 return nil, err 558 } 559 sub.proto.AckDeadlineSeconds = a 560 561 case "retain_acked_messages": 562 sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages 563 564 case "message_retention_duration": 565 if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil { 566 return nil, err 567 } 568 sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration 569 570 case "labels": 571 sub.proto.Labels = req.Subscription.Labels 572 573 case "expiration_policy": 574 sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy 575 576 case "dead_letter_policy": 577 sub.proto.DeadLetterPolicy = req.Subscription.DeadLetterPolicy 578 579 case "retry_policy": 580 sub.proto.RetryPolicy = req.Subscription.RetryPolicy 581 582 case "filter": 583 sub.proto.Filter = req.Subscription.Filter 584 585 default: 586 return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path) 587 } 588 } 589 return sub.proto, nil 590} 591 592func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) { 593 s.mu.Lock() 594 defer s.mu.Unlock() 595 596 if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil { 597 return ret.(*pb.ListSubscriptionsResponse), err 598 } 599 600 var names []string 601 for name := range s.subs { 602 if strings.HasPrefix(name, req.Project) { 603 names = append(names, name) 604 } 605 } 606 sort.Strings(names) 607 from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names)) 608 if err != nil { 609 return nil, err 610 } 611 res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken} 612 for i := from; i < to; i++ { 613 res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto) 614 } 615 return res, nil 616} 617 618func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) { 619 s.mu.Lock() 620 defer s.mu.Unlock() 621 622 if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil { 623 return ret.(*emptypb.Empty), err 624 } 625 626 sub, err := s.findSubscription(req.Subscription) 627 if err != nil { 628 return nil, err 629 } 630 sub.stop() 631 delete(s.subs, req.Subscription) 632 sub.topic.deleteSub(sub) 633 return &emptypb.Empty{}, nil 634} 635 636func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) { 637 s.mu.Lock() 638 defer s.mu.Unlock() 639 640 if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil { 641 return ret.(*pb.DetachSubscriptionResponse), err 642 } 643 644 sub, err := s.findSubscription(req.Subscription) 645 if err != nil { 646 return nil, err 647 } 648 sub.topic.deleteSub(sub) 649 return &pb.DetachSubscriptionResponse{}, nil 650} 651 652func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) { 653 s.mu.Lock() 654 defer s.mu.Unlock() 655 656 if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil { 657 return ret.(*pb.PublishResponse), err 658 } 659 660 if req.Topic == "" { 661 return nil, status.Errorf(codes.InvalidArgument, "missing topic") 662 } 663 top := s.topics[req.Topic] 664 if top == nil { 665 return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic) 666 } 667 668 if !s.autoPublishResponse { 669 r := <-s.publishResponses 670 if r.err != nil { 671 return nil, r.err 672 } 673 return r.resp, nil 674 } 675 676 var ids []string 677 for _, pm := range req.Messages { 678 id := fmt.Sprintf("m%d", s.nextID) 679 s.nextID++ 680 pm.MessageId = id 681 pubTime := s.timeNowFunc() 682 tsPubTime := timestamppb.New(pubTime) 683 pm.PublishTime = tsPubTime 684 m := &Message{ 685 ID: id, 686 Data: pm.Data, 687 Attributes: pm.Attributes, 688 PublishTime: pubTime, 689 OrderingKey: pm.OrderingKey, 690 } 691 top.publish(pm, m) 692 ids = append(ids, id) 693 s.msgs = append(s.msgs, m) 694 s.msgsByID[id] = m 695 } 696 return &pb.PublishResponse{MessageIds: ids}, nil 697} 698 699type topic struct { 700 proto *pb.Topic 701 subs map[string]*subscription 702} 703 704func newTopic(pt *pb.Topic) *topic { 705 return &topic{ 706 proto: pt, 707 subs: map[string]*subscription{}, 708 } 709} 710 711func (t *topic) stop() { 712 for _, sub := range t.subs { 713 sub.proto.Topic = "_deleted-topic_" 714 } 715} 716 717func (t *topic) deleteSub(sub *subscription) { 718 delete(t.subs, sub.proto.Name) 719} 720 721func (t *topic) publish(pm *pb.PubsubMessage, m *Message) { 722 for _, s := range t.subs { 723 s.msgs[pm.MessageId] = &message{ 724 publishTime: m.PublishTime, 725 proto: &pb.ReceivedMessage{ 726 AckId: pm.MessageId, 727 Message: pm, 728 }, 729 deliveries: &m.deliveries, 730 acks: &m.acks, 731 streamIndex: -1, 732 } 733 } 734} 735 736type subscription struct { 737 topic *topic 738 mu *sync.Mutex // the server mutex, here for convenience 739 proto *pb.Subscription 740 ackTimeout time.Duration 741 msgs map[string]*message // unacked messages by message ID 742 streams []*stream 743 done chan struct{} 744 timeNowFunc func() time.Time 745} 746 747func newSubscription(t *topic, mu *sync.Mutex, timeNowFunc func() time.Time, ps *pb.Subscription) *subscription { 748 at := time.Duration(ps.AckDeadlineSeconds) * time.Second 749 if at == 0 { 750 at = 10 * time.Second 751 } 752 return &subscription{ 753 topic: t, 754 mu: mu, 755 proto: ps, 756 ackTimeout: at, 757 msgs: map[string]*message{}, 758 done: make(chan struct{}), 759 timeNowFunc: timeNowFunc, 760 } 761} 762 763func (s *subscription) start(wg *sync.WaitGroup) { 764 wg.Add(1) 765 go func() { 766 defer wg.Done() 767 for { 768 select { 769 case <-s.done: 770 return 771 case <-time.After(10 * time.Millisecond): 772 s.deliver() 773 } 774 } 775 }() 776} 777 778func (s *subscription) stop() { 779 close(s.done) 780} 781 782func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) { 783 s.mu.Lock() 784 defer s.mu.Unlock() 785 786 if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil { 787 return ret.(*emptypb.Empty), err 788 } 789 790 sub, err := s.findSubscription(req.Subscription) 791 if err != nil { 792 return nil, err 793 } 794 for _, id := range req.AckIds { 795 sub.ack(id) 796 } 797 return &emptypb.Empty{}, nil 798} 799 800func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) { 801 s.mu.Lock() 802 defer s.mu.Unlock() 803 804 if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil { 805 return ret.(*emptypb.Empty), err 806 } 807 808 sub, err := s.findSubscription(req.Subscription) 809 if err != nil { 810 return nil, err 811 } 812 now := time.Now() 813 for _, id := range req.AckIds { 814 s.msgsByID[id].modacks = append(s.msgsByID[id].modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now}) 815 } 816 dur := secsToDur(req.AckDeadlineSeconds) 817 for _, id := range req.AckIds { 818 sub.modifyAckDeadline(id, dur) 819 } 820 return &emptypb.Empty{}, nil 821} 822 823func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) { 824 s.mu.Lock() 825 826 if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil { 827 s.mu.Unlock() 828 return ret.(*pb.PullResponse), err 829 } 830 831 sub, err := s.findSubscription(req.Subscription) 832 if err != nil { 833 s.mu.Unlock() 834 return nil, err 835 } 836 max := int(req.MaxMessages) 837 if max < 0 { 838 s.mu.Unlock() 839 return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative") 840 } 841 if max == 0 { // MaxMessages not specified; use a default. 842 max = 1000 843 } 844 msgs := sub.pull(max) 845 s.mu.Unlock() 846 // Implement the spec from the pubsub proto: 847 // "If ReturnImmediately set to true, the system will respond immediately even if 848 // it there are no messages available to return in the `Pull` response. 849 // Otherwise, the system may wait (for a bounded amount of time) until at 850 // least one message is available, rather than returning no messages." 851 if len(msgs) == 0 && !req.ReturnImmediately { 852 // Wait for a short amount of time for a message. 853 // TODO: signal when a message arrives, so we don't wait the whole time. 854 select { 855 case <-ctx.Done(): 856 return nil, ctx.Err() 857 case <-time.After(500 * time.Millisecond): 858 s.mu.Lock() 859 msgs = sub.pull(max) 860 s.mu.Unlock() 861 } 862 } 863 return &pb.PullResponse{ReceivedMessages: msgs}, nil 864} 865 866func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error { 867 // Receive initial message configuring the pull. 868 req, err := sps.Recv() 869 if err != nil { 870 return err 871 } 872 s.mu.Lock() 873 sub, err := s.findSubscription(req.Subscription) 874 s.mu.Unlock() 875 if err != nil { 876 return err 877 } 878 // Create a new stream to handle the pull. 879 st := sub.newStream(sps, s.streamTimeout) 880 err = st.pull(&s.wg) 881 sub.deleteStream(st) 882 return err 883} 884 885func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) { 886 // Only handle time-based seeking for now. 887 // This fake doesn't deal with snapshots. 888 var target time.Time 889 switch v := req.Target.(type) { 890 case nil: 891 return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type") 892 case *pb.SeekRequest_Time: 893 target = v.Time.AsTime() 894 default: 895 return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v) 896 } 897 898 // The entire server must be locked while doing the work below, 899 // because the messages don't have any other synchronization. 900 s.mu.Lock() 901 defer s.mu.Unlock() 902 903 if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil { 904 return ret.(*pb.SeekResponse), err 905 } 906 907 sub, err := s.findSubscription(req.Subscription) 908 if err != nil { 909 return nil, err 910 } 911 // Drop all messages from sub that were published before the target time. 912 for id, m := range sub.msgs { 913 if m.publishTime.Before(target) { 914 delete(sub.msgs, id) 915 (*m.acks)++ 916 } 917 } 918 // Un-ack any already-acked messages after this time; 919 // redelivering them to the subscription is the closest analogue here. 920 for _, m := range s.msgs { 921 if m.PublishTime.Before(target) { 922 continue 923 } 924 sub.msgs[m.ID] = &message{ 925 publishTime: m.PublishTime, 926 proto: &pb.ReceivedMessage{ 927 AckId: m.ID, 928 // This was not preserved! 929 //Message: pm, 930 }, 931 deliveries: &m.deliveries, 932 acks: &m.acks, 933 streamIndex: -1, 934 } 935 } 936 return &pb.SeekResponse{}, nil 937} 938 939// Gets a subscription that must exist. 940// Must be called with the lock held. 941func (s *GServer) findSubscription(name string) (*subscription, error) { 942 if name == "" { 943 return nil, status.Errorf(codes.InvalidArgument, "missing subscription") 944 } 945 sub := s.subs[name] 946 if sub == nil { 947 return nil, status.Errorf(codes.NotFound, "subscription %s", name) 948 } 949 return sub, nil 950} 951 952// Must be called with the lock held. 953func (s *subscription) pull(max int) []*pb.ReceivedMessage { 954 now := s.timeNowFunc() 955 s.maintainMessages(now) 956 var msgs []*pb.ReceivedMessage 957 for _, m := range s.msgs { 958 if m.outstanding() { 959 continue 960 } 961 (*m.deliveries)++ 962 m.ackDeadline = now.Add(s.ackTimeout) 963 msgs = append(msgs, m.proto) 964 if len(msgs) >= max { 965 break 966 } 967 } 968 return msgs 969} 970 971func (s *subscription) deliver() { 972 s.mu.Lock() 973 defer s.mu.Unlock() 974 975 now := s.timeNowFunc() 976 s.maintainMessages(now) 977 // Try to deliver each remaining message. 978 curIndex := 0 979 for _, m := range s.msgs { 980 if m.outstanding() { 981 continue 982 } 983 // If the message was never delivered before, start with the stream at 984 // curIndex. If it was delivered before, start with the stream after the one 985 // that owned it. 986 if m.streamIndex < 0 { 987 delIndex, ok := s.tryDeliverMessage(m, curIndex, now) 988 if !ok { 989 break 990 } 991 curIndex = delIndex + 1 992 m.streamIndex = curIndex 993 } else { 994 delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now) 995 if !ok { 996 break 997 } 998 m.streamIndex = delIndex 999 } 1000 } 1001} 1002 1003// tryDeliverMessage attempts to deliver m to the stream at index i. If it can't, it 1004// tries streams i+1, i+2, ..., wrapping around. Once it's tried all streams, it 1005// exits. 1006// 1007// It returns the index of the stream it delivered the message to, or 0, false if 1008// it didn't deliver the message. 1009// 1010// Must be called with the lock held. 1011func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) { 1012 for i := 0; i < len(s.streams); i++ { 1013 idx := (i + start) % len(s.streams) 1014 1015 st := s.streams[idx] 1016 select { 1017 case <-st.done: 1018 s.streams = deleteStreamAt(s.streams, idx) 1019 i-- 1020 1021 case st.msgc <- m.proto: 1022 (*m.deliveries)++ 1023 m.ackDeadline = now.Add(st.ackTimeout) 1024 return idx, true 1025 1026 default: 1027 } 1028 } 1029 return 0, false 1030} 1031 1032var retentionDuration = 10 * time.Minute 1033 1034// Must be called with the lock held. 1035func (s *subscription) maintainMessages(now time.Time) { 1036 for id, m := range s.msgs { 1037 // Mark a message as re-deliverable if its ack deadline has expired. 1038 if m.outstanding() && now.After(m.ackDeadline) { 1039 m.makeAvailable() 1040 } 1041 pubTime := m.proto.Message.PublishTime.AsTime() 1042 // Remove messages that have been undelivered for a long time. 1043 if !m.outstanding() && now.Sub(pubTime) > retentionDuration { 1044 delete(s.msgs, id) 1045 } 1046 } 1047} 1048 1049func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream { 1050 st := &stream{ 1051 sub: s, 1052 done: make(chan struct{}), 1053 msgc: make(chan *pb.ReceivedMessage), 1054 gstream: gs, 1055 ackTimeout: s.ackTimeout, 1056 timeout: timeout, 1057 } 1058 s.mu.Lock() 1059 s.streams = append(s.streams, st) 1060 s.mu.Unlock() 1061 return st 1062} 1063 1064func (s *subscription) deleteStream(st *stream) { 1065 s.mu.Lock() 1066 defer s.mu.Unlock() 1067 var i int 1068 for i = 0; i < len(s.streams); i++ { 1069 if s.streams[i] == st { 1070 break 1071 } 1072 } 1073 if i < len(s.streams) { 1074 s.streams = deleteStreamAt(s.streams, i) 1075 } 1076} 1077func deleteStreamAt(s []*stream, i int) []*stream { 1078 // Preserve order for round-robin delivery. 1079 return append(s[:i], s[i+1:]...) 1080} 1081 1082type message struct { 1083 proto *pb.ReceivedMessage 1084 publishTime time.Time 1085 ackDeadline time.Time 1086 deliveries *int 1087 acks *int 1088 streamIndex int // index of stream that currently owns msg, for round-robin delivery 1089} 1090 1091// A message is outstanding if it is owned by some stream. 1092func (m *message) outstanding() bool { 1093 return !m.ackDeadline.IsZero() 1094} 1095 1096func (m *message) makeAvailable() { 1097 m.ackDeadline = time.Time{} 1098} 1099 1100type stream struct { 1101 sub *subscription 1102 done chan struct{} // closed when the stream is finished 1103 msgc chan *pb.ReceivedMessage 1104 gstream pb.Subscriber_StreamingPullServer 1105 ackTimeout time.Duration 1106 timeout time.Duration 1107} 1108 1109// pull manages the StreamingPull interaction for the life of the stream. 1110func (st *stream) pull(wg *sync.WaitGroup) error { 1111 errc := make(chan error, 2) 1112 wg.Add(2) 1113 go func() { 1114 defer wg.Done() 1115 errc <- st.sendLoop() 1116 }() 1117 go func() { 1118 defer wg.Done() 1119 errc <- st.recvLoop() 1120 }() 1121 var tchan <-chan time.Time 1122 if st.timeout > 0 { 1123 tchan = time.After(st.timeout) 1124 } 1125 // Wait until one of the goroutines returns an error, or we time out. 1126 var err error 1127 select { 1128 case err = <-errc: 1129 if err == io.EOF { 1130 err = nil 1131 } 1132 case <-tchan: 1133 } 1134 close(st.done) // stop the other goroutine 1135 return err 1136} 1137 1138func (st *stream) sendLoop() error { 1139 for { 1140 select { 1141 case <-st.done: 1142 return nil 1143 case rm := <-st.msgc: 1144 res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}} 1145 if err := st.gstream.Send(res); err != nil { 1146 return err 1147 } 1148 } 1149 } 1150} 1151 1152func (st *stream) recvLoop() error { 1153 for { 1154 req, err := st.gstream.Recv() 1155 if err != nil { 1156 return err 1157 } 1158 st.sub.handleStreamingPullRequest(st, req) 1159 } 1160} 1161 1162func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) { 1163 // Lock the entire server. 1164 s.mu.Lock() 1165 defer s.mu.Unlock() 1166 1167 for _, ackID := range req.AckIds { 1168 s.ack(ackID) 1169 } 1170 for i, id := range req.ModifyDeadlineAckIds { 1171 s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i])) 1172 } 1173 if req.StreamAckDeadlineSeconds > 0 { 1174 st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds) 1175 } 1176} 1177 1178// Must be called with the lock held. 1179func (s *subscription) ack(id string) { 1180 m := s.msgs[id] 1181 if m != nil { 1182 (*m.acks)++ 1183 delete(s.msgs, id) 1184 } 1185} 1186 1187// Must be called with the lock held. 1188func (s *subscription) modifyAckDeadline(id string, d time.Duration) { 1189 m := s.msgs[id] 1190 if m == nil { // already acked: ignore. 1191 return 1192 } 1193 if d == 0 { // nack 1194 m.makeAvailable() 1195 } else { // extend the deadline by d 1196 m.ackDeadline = s.timeNowFunc().Add(d) 1197 } 1198} 1199 1200func secsToDur(secs int32) time.Duration { 1201 return time.Duration(secs) * time.Second 1202} 1203 1204// runReactor looks up the reactors for a function, then launches them until handled=true 1205// or err is returned. If the reactor returns nil, the function returns defaultObj instead. 1206func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) { 1207 if val, ok := s.reactorOptions[funcName]; ok { 1208 for _, reactor := range val { 1209 handled, ret, err := reactor.React(req) 1210 // If handled=true, that means the reactor has successfully reacted to the request, 1211 // so use the output directly. If err occurs, that means the request is invalidated 1212 // by the reactor somehow. 1213 if handled || err != nil { 1214 if ret == nil { 1215 ret = defaultObj 1216 } 1217 return true, ret, err 1218 } 1219 } 1220 } 1221 return false, nil, nil 1222} 1223 1224// errorInjectionReactor is a reactor to inject an error message with status code. 1225type errorInjectionReactor struct { 1226 code codes.Code 1227 msg string 1228} 1229 1230// React simply returns an error with defined error message and status code. 1231func (e *errorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) { 1232 return true, nil, status.Errorf(e.code, e.msg) 1233} 1234 1235// WithErrorInjection creates a ServerReactorOption that injects error with defined status code and 1236// message for a certain function. 1237func WithErrorInjection(funcName string, code codes.Code, msg string) ServerReactorOption { 1238 return ServerReactorOption{ 1239 FuncName: funcName, 1240 Reactor: &errorInjectionReactor{code: code, msg: msg}, 1241 } 1242} 1243 1244func (s *GServer) CreateSchema(_ context.Context, req *pb.CreateSchemaRequest) (*pb.Schema, error) { 1245 s.mu.Lock() 1246 defer s.mu.Unlock() 1247 1248 if handled, ret, err := s.runReactor(req, "CreateSchema", &pb.Schema{}); handled || err != nil { 1249 return ret.(*pb.Schema), err 1250 } 1251 1252 name := fmt.Sprintf("%s/schemas/%s", req.Parent, req.SchemaId) 1253 sc := &pb.Schema{ 1254 Name: name, 1255 Type: req.Schema.Type, 1256 Definition: req.Schema.Definition, 1257 } 1258 s.schemas[name] = sc 1259 1260 return sc, nil 1261} 1262 1263func (s *GServer) GetSchema(_ context.Context, req *pb.GetSchemaRequest) (*pb.Schema, error) { 1264 1265 s.mu.Lock() 1266 defer s.mu.Unlock() 1267 1268 if handled, ret, err := s.runReactor(req, "GetSchema", &pb.Schema{}); handled || err != nil { 1269 return ret.(*pb.Schema), err 1270 } 1271 1272 sc, ok := s.schemas[req.Name] 1273 if !ok { 1274 return nil, status.Errorf(codes.NotFound, "schema(%q) not found", req.Name) 1275 } 1276 return sc, nil 1277} 1278 1279func (s *GServer) ListSchemas(_ context.Context, req *pb.ListSchemasRequest) (*pb.ListSchemasResponse, error) { 1280 s.mu.Lock() 1281 defer s.mu.Unlock() 1282 1283 if handled, ret, err := s.runReactor(req, "ListSchemas", &pb.ListSchemasResponse{}); handled || err != nil { 1284 return ret.(*pb.ListSchemasResponse), err 1285 } 1286 ss := make([]*pb.Schema, 0) 1287 for _, sc := range s.schemas { 1288 ss = append(ss, sc) 1289 } 1290 return &pb.ListSchemasResponse{ 1291 Schemas: ss, 1292 }, nil 1293} 1294 1295func (s *GServer) DeleteSchema(_ context.Context, req *pb.DeleteSchemaRequest) (*emptypb.Empty, error) { 1296 s.mu.Lock() 1297 defer s.mu.Unlock() 1298 1299 if handled, ret, err := s.runReactor(req, "DeleteSchema", &emptypb.Empty{}); handled || err != nil { 1300 return ret.(*emptypb.Empty), err 1301 } 1302 1303 schema := s.schemas[req.Name] 1304 if schema == nil { 1305 return nil, status.Errorf(codes.NotFound, "schema %q", req.Name) 1306 } 1307 1308 delete(s.schemas, req.Name) 1309 return &emptypb.Empty{}, nil 1310} 1311 1312// ValidateSchema mocks the ValidateSchema call but only checks that the schema definition is not empty. 1313func (s *GServer) ValidateSchema(_ context.Context, req *pb.ValidateSchemaRequest) (*pb.ValidateSchemaResponse, error) { 1314 s.mu.Lock() 1315 defer s.mu.Unlock() 1316 1317 if handled, ret, err := s.runReactor(req, "ValidateSchema", &pb.ValidateSchemaResponse{}); handled || err != nil { 1318 return ret.(*pb.ValidateSchemaResponse), err 1319 } 1320 1321 if req.Schema.Definition == "" { 1322 return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty") 1323 } 1324 return &pb.ValidateSchemaResponse{}, nil 1325} 1326 1327// ValidateMessage mocks the ValidateMessage call but only checks that the schema definition to validate the 1328// message against is not empty. 1329func (s *GServer) ValidateMessage(_ context.Context, req *pb.ValidateMessageRequest) (*pb.ValidateMessageResponse, error) { 1330 s.mu.Lock() 1331 defer s.mu.Unlock() 1332 1333 if handled, ret, err := s.runReactor(req, "ValidateMessage", &pb.ValidateMessageResponse{}); handled || err != nil { 1334 return ret.(*pb.ValidateMessageResponse), err 1335 } 1336 1337 spec := req.GetSchemaSpec() 1338 if valReq, ok := spec.(*pb.ValidateMessageRequest_Name); ok { 1339 sc, ok := s.schemas[valReq.Name] 1340 if !ok { 1341 return nil, status.Errorf(codes.NotFound, "schema(%q) not found", valReq.Name) 1342 } 1343 if sc.Definition == "" { 1344 return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty") 1345 } 1346 } 1347 if valReq, ok := spec.(*pb.ValidateMessageRequest_Schema); ok { 1348 if valReq.Schema.Definition == "" { 1349 return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty") 1350 } 1351 } 1352 1353 return &pb.ValidateMessageResponse{}, nil 1354} 1355