1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package pstest provides a fake Cloud PubSub service for testing. It implements a
16// simplified form of the service, suitable for unit tests. It may behave
17// differently from the actual service in ways in which the service is
18// non-deterministic or unspecified: timing, delivery order, etc.
19//
20// This package is EXPERIMENTAL and is subject to change without notice.
21//
22// See the example for usage.
23package pstest
24
25import (
26	"context"
27	"fmt"
28	"io"
29	"path"
30	"sort"
31	"strings"
32	"sync"
33	"sync/atomic"
34	"time"
35
36	"cloud.google.com/go/internal/testutil"
37	pb "google.golang.org/genproto/googleapis/pubsub/v1"
38	"google.golang.org/grpc/codes"
39	"google.golang.org/grpc/status"
40	durpb "google.golang.org/protobuf/types/known/durationpb"
41	"google.golang.org/protobuf/types/known/emptypb"
42	"google.golang.org/protobuf/types/known/timestamppb"
43)
44
// ReactorOptions is the map the Server uses to look up reactors.
// The key is the name of the intercepted function (for example "CreateTopic");
// the value is the list of reactors registered for that function.
type ReactorOptions map[string][]Reactor
48
// Reactor is an interface that lets a test supply its own reaction to a
// certain call made against the fake server.
type Reactor interface {
	// React receives the request of the intercepted call and returns results
	// for it. If "handled" is false, then the test server will ignore the
	// results and continue to the next reactor or the original handler.
	React(_ interface{}) (handled bool, ret interface{}, err error)
}
56
// ServerReactorOption is an option passed to NewServer to register a reactor.
type ServerReactorOption struct {
	// FuncName is the name of the function the reactor is registered for.
	FuncName string
	// Reactor is the reactor appended to the list kept for FuncName.
	Reactor Reactor
}
62
// now holds the func() time.Time that timeNow calls; init stores time.Now in it.
//
// For testing. Note that even though changes to the now variable are atomic, a call
// to the stored function can race with a change to that function. This could be a
// problem if tests are run in parallel, or even if concurrent parts of the same test
// change the value of the variable.
var now atomic.Value
68
69func init() {
70	now.Store(time.Now)
71	ResetMinAckDeadline()
72}
73
74func timeNow() time.Time {
75	return now.Load().(func() time.Time)()
76}
77
// Server is a fake Pub/Sub server. Create one with NewServer.
type Server struct {
	srv     *testutil.Server // in-process test server that hosts the gRPC services
	Addr    string  // The address that the server is listening on.
	GServer GServer // Not intended to be used directly.
}
84
// GServer is the underlying service implementor. It is not intended to be used
// directly. It is registered as both the Publisher and the Subscriber service.
type GServer struct {
	pb.PublisherServer
	pb.SubscriberServer

	mu             sync.Mutex               // server-wide lock; every subscription holds a pointer to it
	topics         map[string]*topic        // topics by name
	subs           map[string]*subscription // subscriptions by name
	msgs           []*Message               // all messages ever published
	msgsByID       map[string]*Message      // published messages by message ID
	wg             sync.WaitGroup           // tracks subscription and stream goroutines; see Server.Wait
	nextID         int                      // counter used to build the next message ID
	streamTimeout  time.Duration            // set by SetStreamTimeout; zero means streams never time out
	timeNowFunc    func() time.Time         // clock used by this server in place of time.Now
	reactorOptions ReactorOptions           // reactors by function name
}
102
103// NewServer creates a new fake server running in the current process.
104func NewServer(opts ...ServerReactorOption) *Server {
105	srv, err := testutil.NewServer()
106	if err != nil {
107		panic(fmt.Sprintf("pstest.NewServer: %v", err))
108	}
109	reactorOptions := ReactorOptions{}
110	for _, opt := range opts {
111		reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor)
112	}
113	s := &Server{
114		srv:  srv,
115		Addr: srv.Addr,
116		GServer: GServer{
117			topics:         map[string]*topic{},
118			subs:           map[string]*subscription{},
119			msgsByID:       map[string]*Message{},
120			timeNowFunc:    timeNow,
121			reactorOptions: reactorOptions,
122		},
123	}
124	pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
125	pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
126	srv.Start()
127	return s
128}
129
130// SetTimeNowFunc registers f as a function to
131// be used instead of time.Now for this server.
132func (s *Server) SetTimeNowFunc(f func() time.Time) {
133	s.GServer.timeNowFunc = f
134}
135
136// Publish behaves as if the Publish RPC was called with a message with the given
137// data and attrs. It returns the ID of the message.
138// The topic will be created if it doesn't exist.
139//
140// Publish panics if there is an error, which is appropriate for testing.
141func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
142	return s.PublishOrdered(topic, data, attrs, "")
143}
144
145// PublishOrdered behaves as if the Publish RPC was called with a message with the given
146// data, attrs and ordering key. It returns the ID of the message.
147// The topic will be created if it doesn't exist.
148//
149// PublishOrdered panics if there is an error, which is appropriate for testing.
150func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]string, orderingKey string) string {
151	const topicPattern = "projects/*/topics/*"
152	ok, err := path.Match(topicPattern, topic)
153	if err != nil {
154		panic(err)
155	}
156	if !ok {
157		panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
158	}
159	_, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
160	req := &pb.PublishRequest{
161		Topic:    topic,
162		Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs, OrderingKey: orderingKey}},
163	}
164	res, err := s.GServer.Publish(context.TODO(), req)
165	if err != nil {
166		panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
167	}
168	return res.MessageIds[0]
169}
170
171// SetStreamTimeout sets the amount of time a stream will be active before it shuts
172// itself down. This mimics the real service's behavior of closing streams after 30
173// minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
174// down.
175func (s *Server) SetStreamTimeout(d time.Duration) {
176	s.GServer.mu.Lock()
177	defer s.GServer.mu.Unlock()
178	s.GServer.streamTimeout = d
179}
180
// A Message is a message that was published to the server.
//
// Deliveries, Acks and Modacks are snapshots: Server.Messages and
// Server.Message copy them from the unexported, mutex-protected fields each
// time they are called.
type Message struct {
	ID          string
	Data        []byte
	Attributes  map[string]string
	PublishTime time.Time
	Deliveries  int      // number of times delivery of the message was attempted
	Acks        int      // number of acks received from clients
	Modacks     []Modack // modacks received by server for this message
	OrderingKey string

	// protected by server mutex
	deliveries int
	acks       int
	modacks    []Modack
}
197
// Modack represents a modack (ModifyAckDeadline request) sent to the server.
type Modack struct {
	AckID       string    // ack ID the modack was sent for
	AckDeadline int32     // value of the request's AckDeadlineSeconds
	ReceivedAt  time.Time // when the server handled the request
}
204
205// Messages returns information about all messages ever published.
206func (s *Server) Messages() []*Message {
207	s.GServer.mu.Lock()
208	defer s.GServer.mu.Unlock()
209
210	var msgs []*Message
211	for _, m := range s.GServer.msgs {
212		m.Deliveries = m.deliveries
213		m.Acks = m.acks
214		m.Modacks = append([]Modack(nil), m.modacks...)
215		msgs = append(msgs, m)
216	}
217	return msgs
218}
219
220// Message returns the message with the given ID, or nil if no message
221// with that ID was published.
222func (s *Server) Message(id string) *Message {
223	s.GServer.mu.Lock()
224	defer s.GServer.mu.Unlock()
225
226	m := s.GServer.msgsByID[id]
227	if m != nil {
228		m.Deliveries = m.deliveries
229		m.Acks = m.acks
230		m.Modacks = append([]Modack(nil), m.modacks...)
231	}
232	return m
233}
234
// Wait blocks until all server activity has completed, that is, until every
// goroutine registered with the server's WaitGroup (the per-subscription
// delivery loops and the per-stream send and receive loops) has returned.
func (s *Server) Wait() {
	s.GServer.wg.Wait()
}
239
240// ClearMessages removes all published messages
241// from internal containers.
242func (s *Server) ClearMessages() {
243	s.GServer.mu.Lock()
244	s.GServer.msgs = nil
245	s.GServer.msgsByID = make(map[string]*Message)
246	s.GServer.mu.Unlock()
247}
248
249// Close shuts down the server and releases all resources.
250func (s *Server) Close() error {
251	s.srv.Close()
252	s.GServer.mu.Lock()
253	defer s.GServer.mu.Unlock()
254	for _, sub := range s.GServer.subs {
255		sub.stop()
256	}
257	return nil
258}
259
260func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
261	s.mu.Lock()
262	defer s.mu.Unlock()
263
264	if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil {
265		return ret.(*pb.Topic), err
266	}
267
268	if s.topics[t.Name] != nil {
269		return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
270	}
271	top := newTopic(t)
272	s.topics[t.Name] = top
273	return top.proto, nil
274}
275
276func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
277	s.mu.Lock()
278	defer s.mu.Unlock()
279
280	if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil {
281		return ret.(*pb.Topic), err
282	}
283
284	if t := s.topics[req.Topic]; t != nil {
285		return t.proto, nil
286	}
287	return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
288}
289
290func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
291	s.mu.Lock()
292	defer s.mu.Unlock()
293
294	if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil {
295		return ret.(*pb.Topic), err
296	}
297
298	t := s.topics[req.Topic.Name]
299	if t == nil {
300		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
301	}
302	for _, path := range req.UpdateMask.Paths {
303		switch path {
304		case "labels":
305			t.proto.Labels = req.Topic.Labels
306		case "message_storage_policy":
307			t.proto.MessageStoragePolicy = req.Topic.MessageStoragePolicy
308		default:
309			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
310		}
311	}
312	return t.proto, nil
313}
314
315func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
316	s.mu.Lock()
317	defer s.mu.Unlock()
318
319	if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil {
320		return ret.(*pb.ListTopicsResponse), err
321	}
322
323	var names []string
324	for n := range s.topics {
325		if strings.HasPrefix(n, req.Project) {
326			names = append(names, n)
327		}
328	}
329	sort.Strings(names)
330	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
331	if err != nil {
332		return nil, err
333	}
334	res := &pb.ListTopicsResponse{NextPageToken: nextToken}
335	for i := from; i < to; i++ {
336		res.Topics = append(res.Topics, s.topics[names[i]].proto)
337	}
338	return res, nil
339}
340
341func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
342	s.mu.Lock()
343	defer s.mu.Unlock()
344
345	if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil {
346		return ret.(*pb.ListTopicSubscriptionsResponse), err
347	}
348
349	var names []string
350	for name, sub := range s.subs {
351		if sub.topic.proto.Name == req.Topic {
352			names = append(names, name)
353		}
354	}
355	sort.Strings(names)
356	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
357	if err != nil {
358		return nil, err
359	}
360	return &pb.ListTopicSubscriptionsResponse{
361		Subscriptions: names[from:to],
362		NextPageToken: nextToken,
363	}, nil
364}
365
366func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
367	s.mu.Lock()
368	defer s.mu.Unlock()
369
370	if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil {
371		return ret.(*emptypb.Empty), err
372	}
373
374	t := s.topics[req.Topic]
375	if t == nil {
376		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
377	}
378	t.stop()
379	delete(s.topics, req.Topic)
380	return &emptypb.Empty{}, nil
381}
382
383func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
384	s.mu.Lock()
385	defer s.mu.Unlock()
386
387	if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil {
388		return ret.(*pb.Subscription), err
389	}
390
391	if ps.Name == "" {
392		return nil, status.Errorf(codes.InvalidArgument, "missing name")
393	}
394	if s.subs[ps.Name] != nil {
395		return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
396	}
397	if ps.Topic == "" {
398		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
399	}
400	top := s.topics[ps.Topic]
401	if top == nil {
402		return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
403	}
404	if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
405		return nil, err
406	}
407	if ps.MessageRetentionDuration == nil {
408		ps.MessageRetentionDuration = defaultMessageRetentionDuration
409	}
410	if err := checkMRD(ps.MessageRetentionDuration); err != nil {
411		return nil, err
412	}
413	if ps.PushConfig == nil {
414		ps.PushConfig = &pb.PushConfig{}
415	}
416
417	sub := newSubscription(top, &s.mu, s.timeNowFunc, ps)
418	top.subs[ps.Name] = sub
419	s.subs[ps.Name] = sub
420	sub.start(&s.wg)
421	return ps, nil
422}
423
// minAckDeadlineSecs is the smallest ack deadline, in seconds, that
// checkAckDeadline accepts. It can be set for testing with SetMinAckDeadline.
var minAckDeadlineSecs int32

// SetMinAckDeadline changes the minimum ack deadline to n, truncated to whole
// seconds. n must be greater than or equal to 1 second; otherwise
// SetMinAckDeadline panics. Remember to reset this value to the default after
// your test changes it. Example usage:
//
//	pstest.SetMinAckDeadline(1 * time.Second)
//	defer pstest.ResetMinAckDeadline()
func SetMinAckDeadline(n time.Duration) {
	if n < time.Second {
		panic("SetMinAckDeadline expects a value greater than or equal to 1 second")
	}

	minAckDeadlineSecs = int32(n / time.Second)
}

// ResetMinAckDeadline resets the minimum ack deadline to the default of 10
// seconds.
func ResetMinAckDeadline() {
	minAckDeadlineSecs = 10
}
444
445func checkAckDeadline(ads int32) error {
446	if ads < minAckDeadlineSecs || ads > 600 {
447		// PubSub service returns Unknown.
448		return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
449	}
450	return nil
451}
452
// Bounds, inclusive, that checkMRD enforces on a subscription's message
// retention duration.
const (
	minMessageRetentionDuration = 10 * time.Minute
	maxMessageRetentionDuration = 168 * time.Hour
)
457
// defaultMessageRetentionDuration is the retention duration given to a new
// subscription that does not specify one; it equals the maximum allowed value.
var defaultMessageRetentionDuration = durpb.New(maxMessageRetentionDuration)
459
460func checkMRD(pmrd *durpb.Duration) error {
461	mrd := pmrd.AsDuration()
462	if mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
463		return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
464	}
465	return nil
466}
467
468func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
469	s.mu.Lock()
470	defer s.mu.Unlock()
471
472	if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil {
473		return ret.(*pb.Subscription), err
474	}
475
476	sub, err := s.findSubscription(req.Subscription)
477	if err != nil {
478		return nil, err
479	}
480	return sub.proto, nil
481}
482
483func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
484	if req.Subscription == nil {
485		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
486	}
487	s.mu.Lock()
488	defer s.mu.Unlock()
489
490	if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil {
491		return ret.(*pb.Subscription), err
492	}
493
494	sub, err := s.findSubscription(req.Subscription.Name)
495	if err != nil {
496		return nil, err
497	}
498	for _, path := range req.UpdateMask.Paths {
499		switch path {
500		case "push_config":
501			sub.proto.PushConfig = req.Subscription.PushConfig
502
503		case "ack_deadline_seconds":
504			a := req.Subscription.AckDeadlineSeconds
505			if err := checkAckDeadline(a); err != nil {
506				return nil, err
507			}
508			sub.proto.AckDeadlineSeconds = a
509
510		case "retain_acked_messages":
511			sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
512
513		case "message_retention_duration":
514			if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
515				return nil, err
516			}
517			sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
518
519		case "labels":
520			sub.proto.Labels = req.Subscription.Labels
521
522		case "expiration_policy":
523			sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy
524
525		case "dead_letter_policy":
526			sub.proto.DeadLetterPolicy = req.Subscription.DeadLetterPolicy
527
528		case "retry_policy":
529			sub.proto.RetryPolicy = req.Subscription.RetryPolicy
530
531		case "filter":
532			sub.proto.Filter = req.Subscription.Filter
533
534		default:
535			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
536		}
537	}
538	return sub.proto, nil
539}
540
541func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
542	s.mu.Lock()
543	defer s.mu.Unlock()
544
545	if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil {
546		return ret.(*pb.ListSubscriptionsResponse), err
547	}
548
549	var names []string
550	for name := range s.subs {
551		if strings.HasPrefix(name, req.Project) {
552			names = append(names, name)
553		}
554	}
555	sort.Strings(names)
556	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
557	if err != nil {
558		return nil, err
559	}
560	res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
561	for i := from; i < to; i++ {
562		res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
563	}
564	return res, nil
565}
566
567func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
568	s.mu.Lock()
569	defer s.mu.Unlock()
570
571	if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil {
572		return ret.(*emptypb.Empty), err
573	}
574
575	sub, err := s.findSubscription(req.Subscription)
576	if err != nil {
577		return nil, err
578	}
579	sub.stop()
580	delete(s.subs, req.Subscription)
581	sub.topic.deleteSub(sub)
582	return &emptypb.Empty{}, nil
583}
584
585func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) {
586	s.mu.Lock()
587	defer s.mu.Unlock()
588
589	if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil {
590		return ret.(*pb.DetachSubscriptionResponse), err
591	}
592
593	sub, err := s.findSubscription(req.Subscription)
594	if err != nil {
595		return nil, err
596	}
597	sub.topic.deleteSub(sub)
598	return &pb.DetachSubscriptionResponse{}, nil
599}
600
601func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
602	s.mu.Lock()
603	defer s.mu.Unlock()
604
605	if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil {
606		return ret.(*pb.PublishResponse), err
607	}
608
609	if req.Topic == "" {
610		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
611	}
612	top := s.topics[req.Topic]
613	if top == nil {
614		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
615	}
616	var ids []string
617	for _, pm := range req.Messages {
618		id := fmt.Sprintf("m%d", s.nextID)
619		s.nextID++
620		pm.MessageId = id
621		pubTime := s.timeNowFunc()
622		tsPubTime := timestamppb.New(pubTime)
623		pm.PublishTime = tsPubTime
624		m := &Message{
625			ID:          id,
626			Data:        pm.Data,
627			Attributes:  pm.Attributes,
628			PublishTime: pubTime,
629			OrderingKey: pm.OrderingKey,
630		}
631		top.publish(pm, m)
632		ids = append(ids, id)
633		s.msgs = append(s.msgs, m)
634		s.msgsByID[id] = m
635	}
636	return &pb.PublishResponse{MessageIds: ids}, nil
637}
638
// topic is the server-side state of one topic.
type topic struct {
	proto *pb.Topic                // the topic as given to CreateTopic
	subs  map[string]*subscription // subscriptions attached to the topic, by subscription name
}
643
644func newTopic(pt *pb.Topic) *topic {
645	return &topic{
646		proto: pt,
647		subs:  map[string]*subscription{},
648	}
649}
650
651func (t *topic) stop() {
652	for _, sub := range t.subs {
653		sub.proto.Topic = "_deleted-topic_"
654	}
655}
656
657func (t *topic) deleteSub(sub *subscription) {
658	delete(t.subs, sub.proto.Name)
659}
660
661func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
662	for _, s := range t.subs {
663		s.msgs[pm.MessageId] = &message{
664			publishTime: m.PublishTime,
665			proto: &pb.ReceivedMessage{
666				AckId:   pm.MessageId,
667				Message: pm,
668			},
669			deliveries:  &m.deliveries,
670			acks:        &m.acks,
671			streamIndex: -1,
672		}
673	}
674}
675
// subscription is the server-side state of one subscription.
type subscription struct {
	topic       *topic              // topic the subscription was created on
	mu          *sync.Mutex         // the server mutex, here for convenience
	proto       *pb.Subscription    // the subscription as given to CreateSubscription
	ackTimeout  time.Duration       // ack deadline applied to messages handed out by pull
	msgs        map[string]*message // unacked messages by message ID
	streams     []*stream           // open streaming-pull streams, in round-robin order
	done        chan struct{}       // closed by stop to end the delivery loop
	timeNowFunc func() time.Time    // clock captured when the subscription was created
}
686
687func newSubscription(t *topic, mu *sync.Mutex, timeNowFunc func() time.Time, ps *pb.Subscription) *subscription {
688	at := time.Duration(ps.AckDeadlineSeconds) * time.Second
689	if at == 0 {
690		at = 10 * time.Second
691	}
692	return &subscription{
693		topic:       t,
694		mu:          mu,
695		proto:       ps,
696		ackTimeout:  at,
697		msgs:        map[string]*message{},
698		done:        make(chan struct{}),
699		timeNowFunc: timeNowFunc,
700	}
701}
702
703func (s *subscription) start(wg *sync.WaitGroup) {
704	wg.Add(1)
705	go func() {
706		defer wg.Done()
707		for {
708			select {
709			case <-s.done:
710				return
711			case <-time.After(10 * time.Millisecond):
712				s.deliver()
713			}
714		}
715	}()
716}
717
718func (s *subscription) stop() {
719	close(s.done)
720}
721
722func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
723	s.mu.Lock()
724	defer s.mu.Unlock()
725
726	if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil {
727		return ret.(*emptypb.Empty), err
728	}
729
730	sub, err := s.findSubscription(req.Subscription)
731	if err != nil {
732		return nil, err
733	}
734	for _, id := range req.AckIds {
735		sub.ack(id)
736	}
737	return &emptypb.Empty{}, nil
738}
739
740func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
741	s.mu.Lock()
742	defer s.mu.Unlock()
743
744	if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil {
745		return ret.(*emptypb.Empty), err
746	}
747
748	sub, err := s.findSubscription(req.Subscription)
749	if err != nil {
750		return nil, err
751	}
752	now := time.Now()
753	for _, id := range req.AckIds {
754		s.msgsByID[id].modacks = append(s.msgsByID[id].modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
755	}
756	dur := secsToDur(req.AckDeadlineSeconds)
757	for _, id := range req.AckIds {
758		sub.modifyAckDeadline(id, dur)
759	}
760	return &emptypb.Empty{}, nil
761}
762
763func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
764	s.mu.Lock()
765
766	if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil {
767		s.mu.Unlock()
768		return ret.(*pb.PullResponse), err
769	}
770
771	sub, err := s.findSubscription(req.Subscription)
772	if err != nil {
773		s.mu.Unlock()
774		return nil, err
775	}
776	max := int(req.MaxMessages)
777	if max < 0 {
778		s.mu.Unlock()
779		return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative")
780	}
781	if max == 0 { // MaxMessages not specified; use a default.
782		max = 1000
783	}
784	msgs := sub.pull(max)
785	s.mu.Unlock()
786	// Implement the spec from the pubsub proto:
787	// "If ReturnImmediately set to true, the system will respond immediately even if
788	// it there are no messages available to return in the `Pull` response.
789	// Otherwise, the system may wait (for a bounded amount of time) until at
790	// least one message is available, rather than returning no messages."
791	if len(msgs) == 0 && !req.ReturnImmediately {
792		// Wait for a short amount of time for a message.
793		// TODO: signal when a message arrives, so we don't wait the whole time.
794		select {
795		case <-ctx.Done():
796			return nil, ctx.Err()
797		case <-time.After(500 * time.Millisecond):
798			s.mu.Lock()
799			msgs = sub.pull(max)
800			s.mu.Unlock()
801		}
802	}
803	return &pb.PullResponse{ReceivedMessages: msgs}, nil
804}
805
806func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
807	// Receive initial message configuring the pull.
808	req, err := sps.Recv()
809	if err != nil {
810		return err
811	}
812	s.mu.Lock()
813	sub, err := s.findSubscription(req.Subscription)
814	s.mu.Unlock()
815	if err != nil {
816		return err
817	}
818	// Create a new stream to handle the pull.
819	st := sub.newStream(sps, s.streamTimeout)
820	err = st.pull(&s.wg)
821	sub.deleteStream(st)
822	return err
823}
824
825func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
826	// Only handle time-based seeking for now.
827	// This fake doesn't deal with snapshots.
828	var target time.Time
829	switch v := req.Target.(type) {
830	case nil:
831		return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
832	case *pb.SeekRequest_Time:
833		target = v.Time.AsTime()
834	default:
835		return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v)
836	}
837
838	// The entire server must be locked while doing the work below,
839	// because the messages don't have any other synchronization.
840	s.mu.Lock()
841	defer s.mu.Unlock()
842
843	if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil {
844		return ret.(*pb.SeekResponse), err
845	}
846
847	sub, err := s.findSubscription(req.Subscription)
848	if err != nil {
849		return nil, err
850	}
851	// Drop all messages from sub that were published before the target time.
852	for id, m := range sub.msgs {
853		if m.publishTime.Before(target) {
854			delete(sub.msgs, id)
855			(*m.acks)++
856		}
857	}
858	// Un-ack any already-acked messages after this time;
859	// redelivering them to the subscription is the closest analogue here.
860	for _, m := range s.msgs {
861		if m.PublishTime.Before(target) {
862			continue
863		}
864		sub.msgs[m.ID] = &message{
865			publishTime: m.PublishTime,
866			proto: &pb.ReceivedMessage{
867				AckId: m.ID,
868				// This was not preserved!
869				//Message: pm,
870			},
871			deliveries:  &m.deliveries,
872			acks:        &m.acks,
873			streamIndex: -1,
874		}
875	}
876	return &pb.SeekResponse{}, nil
877}
878
879// Gets a subscription that must exist.
880// Must be called with the lock held.
881func (s *GServer) findSubscription(name string) (*subscription, error) {
882	if name == "" {
883		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
884	}
885	sub := s.subs[name]
886	if sub == nil {
887		return nil, status.Errorf(codes.NotFound, "subscription %s", name)
888	}
889	return sub, nil
890}
891
892// Must be called with the lock held.
893func (s *subscription) pull(max int) []*pb.ReceivedMessage {
894	now := s.timeNowFunc()
895	s.maintainMessages(now)
896	var msgs []*pb.ReceivedMessage
897	for _, m := range s.msgs {
898		if m.outstanding() {
899			continue
900		}
901		(*m.deliveries)++
902		m.ackDeadline = now.Add(s.ackTimeout)
903		msgs = append(msgs, m.proto)
904		if len(msgs) >= max {
905			break
906		}
907	}
908	return msgs
909}
910
// deliver makes one pass over the subscription's messages and tries to push
// every message that is not outstanding to one of the open streams. It takes
// the server mutex itself. The pass stops at the first message that no stream
// accepts.
func (s *subscription) deliver() {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := s.timeNowFunc()
	// Expire ack deadlines and drop old messages before delivering.
	s.maintainMessages(now)
	// Try to deliver each remaining message.
	// curIndex is the stream to try first for a never-delivered message; it
	// advances past the stream used each time such a message is delivered.
	curIndex := 0
	for _, m := range s.msgs {
		if m.outstanding() {
			continue
		}
		// If the message was never delivered before, start with the stream at
		// curIndex. If it was delivered before, start with the stream after the one
		// that owned it.
		if m.streamIndex < 0 {
			delIndex, ok := s.tryDeliverMessage(m, curIndex, now)
			if !ok {
				break
			}
			curIndex = delIndex + 1
			// Note: what is stored is the index after the delivering stream.
			m.streamIndex = curIndex
		} else {
			delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now)
			if !ok {
				break
			}
			// Note: here the index of the delivering stream itself is stored.
			m.streamIndex = delIndex
		}
	}
}
942
// tryDeliverMessage attempts to deliver m to the stream at index i. If it can't, it
// tries streams i+1, i+2, ..., wrapping around. Once it's tried all streams, it
// exits.
//
// Streams found to be finished along the way are removed from the
// subscription's stream list.
//
// It returns the index of the stream it delivered the message to, or 0, false if
// it didn't deliver the message.
//
// Must be called with the lock held.
func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) {
	for i := 0; i < len(s.streams); i++ {
		idx := (i + start) % len(s.streams)

		st := s.streams[idx]
		select {
		case <-st.done:
			// The stream has finished: drop it and retry the same offset,
			// which now refers to the stream that followed it.
			s.streams = deleteStreamAt(s.streams, idx)
			i--

		case st.msgc <- m.proto:
			// The stream accepted the message; it is now outstanding until
			// the stream's ack timeout elapses.
			(*m.deliveries)++
			m.ackDeadline = now.Add(st.ackTimeout)
			return idx, true

		default:
			// The stream is not ready to take a message right now; try the next one.
		}
	}
	return 0, false
}
971
// retentionDuration is how long after its publish time a message that is not
// outstanding is kept in a subscription before maintainMessages removes it.
var retentionDuration = 10 * time.Minute
973
974// Must be called with the lock held.
975func (s *subscription) maintainMessages(now time.Time) {
976	for id, m := range s.msgs {
977		// Mark a message as re-deliverable if its ack deadline has expired.
978		if m.outstanding() && now.After(m.ackDeadline) {
979			m.makeAvailable()
980		}
981		pubTime := m.proto.Message.PublishTime.AsTime()
982		// Remove messages that have been undelivered for a long time.
983		if !m.outstanding() && now.Sub(pubTime) > retentionDuration {
984			delete(s.msgs, id)
985		}
986	}
987}
988
989func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
990	st := &stream{
991		sub:        s,
992		done:       make(chan struct{}),
993		msgc:       make(chan *pb.ReceivedMessage),
994		gstream:    gs,
995		ackTimeout: s.ackTimeout,
996		timeout:    timeout,
997	}
998	s.mu.Lock()
999	s.streams = append(s.streams, st)
1000	s.mu.Unlock()
1001	return st
1002}
1003
1004func (s *subscription) deleteStream(st *stream) {
1005	s.mu.Lock()
1006	defer s.mu.Unlock()
1007	var i int
1008	for i = 0; i < len(s.streams); i++ {
1009		if s.streams[i] == st {
1010			break
1011		}
1012	}
1013	if i < len(s.streams) {
1014		s.streams = deleteStreamAt(s.streams, i)
1015	}
1016}
1017func deleteStreamAt(s []*stream, i int) []*stream {
1018	// Preserve order for round-robin delivery.
1019	return append(s[:i], s[i+1:]...)
1020}
1021
// message is a subscription's view of one published message.
type message struct {
	proto       *pb.ReceivedMessage // what is sent to clients
	publishTime time.Time           // when the message was published
	ackDeadline time.Time           // zero unless the message is outstanding
	deliveries  *int                // points at the delivery counter of the server's Message
	acks        *int                // points at the ack counter of the server's Message
	streamIndex int // index of stream that currently owns msg, for round-robin delivery
}
1030
1031// A message is outstanding if it is owned by some stream.
1032func (m *message) outstanding() bool {
1033	return !m.ackDeadline.IsZero()
1034}
1035
1036func (m *message) makeAvailable() {
1037	m.ackDeadline = time.Time{}
1038}
1039
// stream is the server-side state for one StreamingPull call on a subscription.
type stream struct {
	// sub is the subscription this stream was created by and is registered with.
	sub        *subscription
	done       chan struct{} // closed when the stream is finished
	// msgc carries messages to sendLoop, which forwards each one to the client.
	// It is unbuffered.
	msgc       chan *pb.ReceivedMessage
	// gstream is the underlying gRPC stream used for Send and Recv.
	gstream    pb.Subscriber_StreamingPullServer
	// ackTimeout starts as the subscription's ackTimeout and is replaced when
	// a request carries a positive StreamAckDeadlineSeconds.
	ackTimeout time.Duration
	// timeout, when positive, bounds how long pull waits before finishing the stream.
	timeout    time.Duration
}
1048
1049// pull manages the StreamingPull interaction for the life of the stream.
1050func (st *stream) pull(wg *sync.WaitGroup) error {
1051	errc := make(chan error, 2)
1052	wg.Add(2)
1053	go func() {
1054		defer wg.Done()
1055		errc <- st.sendLoop()
1056	}()
1057	go func() {
1058		defer wg.Done()
1059		errc <- st.recvLoop()
1060	}()
1061	var tchan <-chan time.Time
1062	if st.timeout > 0 {
1063		tchan = time.After(st.timeout)
1064	}
1065	// Wait until one of the goroutines returns an error, or we time out.
1066	var err error
1067	select {
1068	case err = <-errc:
1069		if err == io.EOF {
1070			err = nil
1071		}
1072	case <-tchan:
1073	}
1074	close(st.done) // stop the other goroutine
1075	return err
1076}
1077
1078func (st *stream) sendLoop() error {
1079	for {
1080		select {
1081		case <-st.done:
1082			return nil
1083		case rm := <-st.msgc:
1084			res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}}
1085			if err := st.gstream.Send(res); err != nil {
1086				return err
1087			}
1088		}
1089	}
1090}
1091
1092func (st *stream) recvLoop() error {
1093	for {
1094		req, err := st.gstream.Recv()
1095		if err != nil {
1096			return err
1097		}
1098		st.sub.handleStreamingPullRequest(st, req)
1099	}
1100}
1101
1102func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
1103	// Lock the entire server.
1104	s.mu.Lock()
1105	defer s.mu.Unlock()
1106
1107	for _, ackID := range req.AckIds {
1108		s.ack(ackID)
1109	}
1110	for i, id := range req.ModifyDeadlineAckIds {
1111		s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
1112	}
1113	if req.StreamAckDeadlineSeconds > 0 {
1114		st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
1115	}
1116}
1117
1118// Must be called with the lock held.
1119func (s *subscription) ack(id string) {
1120	m := s.msgs[id]
1121	if m != nil {
1122		(*m.acks)++
1123		delete(s.msgs, id)
1124	}
1125}
1126
1127// Must be called with the lock held.
1128func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
1129	m := s.msgs[id]
1130	if m == nil { // already acked: ignore.
1131		return
1132	}
1133	if d == 0 { // nack
1134		m.makeAvailable()
1135	} else { // extend the deadline by d
1136		m.ackDeadline = s.timeNowFunc().Add(d)
1137	}
1138}
1139
// secsToDur converts a whole number of seconds into a time.Duration.
func secsToDur(secs int32) time.Duration {
	return time.Second * time.Duration(secs)
}
1143
1144// runReactor looks up the reactors for a function, then launches them until handled=true
1145// or err is returned. If the reactor returns nil, the function returns defaultObj instead.
1146func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) {
1147	if val, ok := s.reactorOptions[funcName]; ok {
1148		for _, reactor := range val {
1149			handled, ret, err := reactor.React(req)
1150			// If handled=true, that means the reactor has successfully reacted to the request,
1151			// so use the output directly. If err occurs, that means the request is invalidated
1152			// by the reactor somehow.
1153			if handled || err != nil {
1154				if ret == nil {
1155					ret = defaultObj
1156				}
1157				return true, ret, err
1158			}
1159		}
1160	}
1161	return false, nil, nil
1162}
1163
// errorInjectionReactor is a reactor to inject an error message with status code.
// Its React method ignores the request and always fails with that error.
type errorInjectionReactor struct {
	// code is the status code of the injected error.
	code codes.Code
	// msg is the message of the injected error.
	msg  string
}
1169
1170// React simply returns an error with defined error message and status code.
1171func (e *errorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) {
1172	return true, nil, status.Errorf(e.code, e.msg)
1173}
1174
1175// WithErrorInjection creates a ServerReactorOption that injects error with defined status code and
1176// message for a certain function.
1177func WithErrorInjection(funcName string, code codes.Code, msg string) ServerReactorOption {
1178	return ServerReactorOption{
1179		FuncName: funcName,
1180		Reactor:  &errorInjectionReactor{code: code, msg: msg},
1181	}
1182}
1183