1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package pstest provides a fake Cloud PubSub service for testing. It implements a
16// simplified form of the service, suitable for unit tests. It may behave
17// differently from the actual service in ways in which the service is
18// non-deterministic or unspecified: timing, delivery order, etc.
19//
20// This package is EXPERIMENTAL and is subject to change without notice.
21//
22// See the example for usage.
23package pstest
24
25import (
26	"context"
27	"fmt"
28	"io"
29	"path"
30	"sort"
31	"strings"
32	"sync"
33	"sync/atomic"
34	"time"
35
36	"cloud.google.com/go/internal/testutil"
37	"github.com/golang/protobuf/ptypes"
38	durpb "github.com/golang/protobuf/ptypes/duration"
39	emptypb "github.com/golang/protobuf/ptypes/empty"
40	pb "google.golang.org/genproto/googleapis/pubsub/v1"
41	"google.golang.org/grpc/codes"
42	"google.golang.org/grpc/status"
43)
44
// now holds the func() time.Time that timeNow calls, so tests can swap in a
// fake clock. Storing and loading the function value is atomic, but a caller
// may still be running a previously loaded function after another one has
// been stored. That can matter when tests run in parallel, or when concurrent
// parts of one test change the value.
var (
	now atomic.Value
)
50
51func init() {
52	now.Store(time.Now)
53	ResetMinAckDeadline()
54}
55
56func timeNow() time.Time {
57	return now.Load().(func() time.Time)()
58}
59
60// Server is a fake Pub/Sub server.
61type Server struct {
62	srv     *testutil.Server
63	Addr    string  // The address that the server is listening on.
64	GServer GServer // Not intended to be used directly.
65}
66
67// GServer is the underlying service implementor. It is not intended to be used
68// directly.
69type GServer struct {
70	pb.PublisherServer
71	pb.SubscriberServer
72
73	mu            sync.Mutex
74	topics        map[string]*topic
75	subs          map[string]*subscription
76	msgs          []*Message // all messages ever published
77	msgsByID      map[string]*Message
78	wg            sync.WaitGroup
79	nextID        int
80	streamTimeout time.Duration
81}
82
83// NewServer creates a new fake server running in the current process.
84func NewServer() *Server {
85	srv, err := testutil.NewServer()
86	if err != nil {
87		panic(fmt.Sprintf("pstest.NewServer: %v", err))
88	}
89	s := &Server{
90		srv:  srv,
91		Addr: srv.Addr,
92		GServer: GServer{
93			topics:   map[string]*topic{},
94			subs:     map[string]*subscription{},
95			msgsByID: map[string]*Message{},
96		},
97	}
98	pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
99	pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
100	srv.Start()
101	return s
102}
103
104// Publish behaves as if the Publish RPC was called with a message with the given
105// data and attrs. It returns the ID of the message.
106// The topic will be created if it doesn't exist.
107//
108// Publish panics if there is an error, which is appropriate for testing.
109func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
110	const topicPattern = "projects/*/topics/*"
111	ok, err := path.Match(topicPattern, topic)
112	if err != nil {
113		panic(err)
114	}
115	if !ok {
116		panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
117	}
118	_, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
119	req := &pb.PublishRequest{
120		Topic:    topic,
121		Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}},
122	}
123	res, err := s.GServer.Publish(context.TODO(), req)
124	if err != nil {
125		panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
126	}
127	return res.MessageIds[0]
128}
129
130// SetStreamTimeout sets the amount of time a stream will be active before it shuts
131// itself down. This mimics the real service's behavior of closing streams after 30
132// minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
133// down.
134func (s *Server) SetStreamTimeout(d time.Duration) {
135	s.GServer.mu.Lock()
136	defer s.GServer.mu.Unlock()
137	s.GServer.streamTimeout = d
138}
139
140// A Message is a message that was published to the server.
141type Message struct {
142	ID          string
143	Data        []byte
144	Attributes  map[string]string
145	PublishTime time.Time
146	Deliveries  int // number of times delivery of the message was attempted
147	Acks        int // number of acks received from clients
148
149	// protected by server mutex
150	deliveries int
151	acks       int
152	Modacks    []Modack // modacks received by server for this message
153
154}
155
// Modack represents a modack (ack-deadline modification) sent to the server.
type Modack struct {
	// AckID is the ack ID the modack applied to.
	AckID string

	// AckDeadline is the requested deadline, in seconds.
	AckDeadline int32

	// ReceivedAt is when the server received the modack.
	ReceivedAt time.Time
}
162
163// Messages returns information about all messages ever published.
164func (s *Server) Messages() []*Message {
165	s.GServer.mu.Lock()
166	defer s.GServer.mu.Unlock()
167
168	var msgs []*Message
169	for _, m := range s.GServer.msgs {
170		m.Deliveries = m.deliveries
171		m.Acks = m.acks
172		msgs = append(msgs, m)
173	}
174	return msgs
175}
176
177// Message returns the message with the given ID, or nil if no message
178// with that ID was published.
179func (s *Server) Message(id string) *Message {
180	s.GServer.mu.Lock()
181	defer s.GServer.mu.Unlock()
182
183	m := s.GServer.msgsByID[id]
184	if m != nil {
185		m.Deliveries = m.deliveries
186		m.Acks = m.acks
187	}
188	return m
189}
190
191// Wait blocks until all server activity has completed.
192func (s *Server) Wait() {
193	s.GServer.wg.Wait()
194}
195
196// ClearMessages removes all published messages
197// from internal containers.
198func (s *Server) ClearMessages() {
199	s.GServer.mu.Lock()
200	s.GServer.msgs = nil
201	s.GServer.msgsByID = make(map[string]*Message)
202	s.GServer.mu.Unlock()
203}
204
205// Close shuts down the server and releases all resources.
206func (s *Server) Close() error {
207	s.srv.Close()
208	s.GServer.mu.Lock()
209	defer s.GServer.mu.Unlock()
210	for _, sub := range s.GServer.subs {
211		sub.stop()
212	}
213	return nil
214}
215
216func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
217	s.mu.Lock()
218	defer s.mu.Unlock()
219
220	if s.topics[t.Name] != nil {
221		return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
222	}
223	top := newTopic(t)
224	s.topics[t.Name] = top
225	return top.proto, nil
226}
227
228func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
229	s.mu.Lock()
230	defer s.mu.Unlock()
231
232	if t := s.topics[req.Topic]; t != nil {
233		return t.proto, nil
234	}
235	return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
236}
237
238func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
239	s.mu.Lock()
240	defer s.mu.Unlock()
241
242	t := s.topics[req.Topic.Name]
243	if t == nil {
244		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
245	}
246	for _, path := range req.UpdateMask.Paths {
247		switch path {
248		case "labels":
249			t.proto.Labels = req.Topic.Labels
250		case "message_storage_policy":
251			t.proto.MessageStoragePolicy = req.Topic.MessageStoragePolicy
252		default:
253			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
254		}
255	}
256	return t.proto, nil
257}
258
259func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
260	s.mu.Lock()
261	defer s.mu.Unlock()
262
263	var names []string
264	for n := range s.topics {
265		if strings.HasPrefix(n, req.Project) {
266			names = append(names, n)
267		}
268	}
269	sort.Strings(names)
270	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
271	if err != nil {
272		return nil, err
273	}
274	res := &pb.ListTopicsResponse{NextPageToken: nextToken}
275	for i := from; i < to; i++ {
276		res.Topics = append(res.Topics, s.topics[names[i]].proto)
277	}
278	return res, nil
279}
280
281func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
282	s.mu.Lock()
283	defer s.mu.Unlock()
284
285	var names []string
286	for name, sub := range s.subs {
287		if sub.topic.proto.Name == req.Topic {
288			names = append(names, name)
289		}
290	}
291	sort.Strings(names)
292	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
293	if err != nil {
294		return nil, err
295	}
296	return &pb.ListTopicSubscriptionsResponse{
297		Subscriptions: names[from:to],
298		NextPageToken: nextToken,
299	}, nil
300}
301
302func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
303	s.mu.Lock()
304	defer s.mu.Unlock()
305
306	t := s.topics[req.Topic]
307	if t == nil {
308		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
309	}
310	t.stop()
311	delete(s.topics, req.Topic)
312	return &emptypb.Empty{}, nil
313}
314
315func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
316	s.mu.Lock()
317	defer s.mu.Unlock()
318
319	if ps.Name == "" {
320		return nil, status.Errorf(codes.InvalidArgument, "missing name")
321	}
322	if s.subs[ps.Name] != nil {
323		return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
324	}
325	if ps.Topic == "" {
326		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
327	}
328	top := s.topics[ps.Topic]
329	if top == nil {
330		return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
331	}
332	if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
333		return nil, err
334	}
335	if ps.MessageRetentionDuration == nil {
336		ps.MessageRetentionDuration = defaultMessageRetentionDuration
337	}
338	if err := checkMRD(ps.MessageRetentionDuration); err != nil {
339		return nil, err
340	}
341	if ps.PushConfig == nil {
342		ps.PushConfig = &pb.PushConfig{}
343	}
344
345	sub := newSubscription(top, &s.mu, ps)
346	top.subs[ps.Name] = sub
347	s.subs[ps.Name] = sub
348	sub.start(&s.wg)
349	return ps, nil
350}
351
// minAckDeadlineSecs is the smallest ack deadline, in seconds, that
// checkAckDeadline accepts. Tests change it with SetMinAckDeadline and restore
// it with ResetMinAckDeadline.
var (
	minAckDeadlineSecs int32
)
354
355// SetMinAckDeadline changes the minack deadline to n. Must be
356// greater than or equal to 1 second. Remember to reset this value
357// to the default after your test changes it. Example usage:
358// 		pstest.SetMinAckDeadlineSecs(1)
359// 		defer pstest.ResetMinAckDeadlineSecs()
360func SetMinAckDeadline(n time.Duration) {
361	if n < time.Second {
362		panic("SetMinAckDeadline expects a value greater than 1 second")
363	}
364
365	minAckDeadlineSecs = int32(n / time.Second)
366}
367
368// ResetMinAckDeadline resets the minack deadline to the default.
369func ResetMinAckDeadline() {
370	minAckDeadlineSecs = 10
371}
372
373func checkAckDeadline(ads int32) error {
374	if ads < minAckDeadlineSecs || ads > 600 {
375		// PubSub service returns Unknown.
376		return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
377	}
378	return nil
379}
380
381const (
382	minMessageRetentionDuration = 10 * time.Minute
383	maxMessageRetentionDuration = 168 * time.Hour
384)
385
386var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration)
387
388func checkMRD(pmrd *durpb.Duration) error {
389	mrd, err := ptypes.Duration(pmrd)
390	if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
391		return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
392	}
393	return nil
394}
395
396func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
397	s.mu.Lock()
398	defer s.mu.Unlock()
399	sub, err := s.findSubscription(req.Subscription)
400	if err != nil {
401		return nil, err
402	}
403	return sub.proto, nil
404}
405
406func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
407	if req.Subscription == nil {
408		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
409	}
410	s.mu.Lock()
411	defer s.mu.Unlock()
412	sub, err := s.findSubscription(req.Subscription.Name)
413	if err != nil {
414		return nil, err
415	}
416	for _, path := range req.UpdateMask.Paths {
417		switch path {
418		case "push_config":
419			sub.proto.PushConfig = req.Subscription.PushConfig
420
421		case "ack_deadline_seconds":
422			a := req.Subscription.AckDeadlineSeconds
423			if err := checkAckDeadline(a); err != nil {
424				return nil, err
425			}
426			sub.proto.AckDeadlineSeconds = a
427
428		case "retain_acked_messages":
429			sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
430
431		case "message_retention_duration":
432			if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
433				return nil, err
434			}
435			sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
436
437		case "labels":
438			sub.proto.Labels = req.Subscription.Labels
439
440		case "expiration_policy":
441			sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy
442
443		default:
444			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
445		}
446	}
447	return sub.proto, nil
448}
449
450func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
451	s.mu.Lock()
452	defer s.mu.Unlock()
453
454	var names []string
455	for name := range s.subs {
456		if strings.HasPrefix(name, req.Project) {
457			names = append(names, name)
458		}
459	}
460	sort.Strings(names)
461	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
462	if err != nil {
463		return nil, err
464	}
465	res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
466	for i := from; i < to; i++ {
467		res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
468	}
469	return res, nil
470}
471
472func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
473	s.mu.Lock()
474	defer s.mu.Unlock()
475	sub, err := s.findSubscription(req.Subscription)
476	if err != nil {
477		return nil, err
478	}
479	sub.stop()
480	delete(s.subs, req.Subscription)
481	sub.topic.deleteSub(sub)
482	return &emptypb.Empty{}, nil
483}
484
485func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
486	s.mu.Lock()
487	defer s.mu.Unlock()
488
489	if req.Topic == "" {
490		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
491	}
492	top := s.topics[req.Topic]
493	if top == nil {
494		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
495	}
496	var ids []string
497	for _, pm := range req.Messages {
498		id := fmt.Sprintf("m%d", s.nextID)
499		s.nextID++
500		pm.MessageId = id
501		pubTime := timeNow()
502		tsPubTime, err := ptypes.TimestampProto(pubTime)
503		if err != nil {
504			return nil, status.Errorf(codes.Internal, err.Error())
505		}
506		pm.PublishTime = tsPubTime
507		m := &Message{
508			ID:          id,
509			Data:        pm.Data,
510			Attributes:  pm.Attributes,
511			PublishTime: pubTime,
512		}
513		top.publish(pm, m)
514		ids = append(ids, id)
515		s.msgs = append(s.msgs, m)
516		s.msgsByID[id] = m
517	}
518	return &pb.PublishResponse{MessageIds: ids}, nil
519}
520
521type topic struct {
522	proto *pb.Topic
523	subs  map[string]*subscription
524}
525
526func newTopic(pt *pb.Topic) *topic {
527	return &topic{
528		proto: pt,
529		subs:  map[string]*subscription{},
530	}
531}
532
533func (t *topic) stop() {
534	for _, sub := range t.subs {
535		sub.proto.Topic = "_deleted-topic_"
536		sub.stop()
537	}
538}
539
540func (t *topic) deleteSub(sub *subscription) {
541	delete(t.subs, sub.proto.Name)
542}
543
544func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
545	for _, s := range t.subs {
546		s.msgs[pm.MessageId] = &message{
547			publishTime: m.PublishTime,
548			proto: &pb.ReceivedMessage{
549				AckId:   pm.MessageId,
550				Message: pm,
551			},
552			deliveries:  &m.deliveries,
553			acks:        &m.acks,
554			streamIndex: -1,
555		}
556	}
557}
558
559type subscription struct {
560	topic      *topic
561	mu         *sync.Mutex // the server mutex, here for convenience
562	proto      *pb.Subscription
563	ackTimeout time.Duration
564	msgs       map[string]*message // unacked messages by message ID
565	streams    []*stream
566	done       chan struct{}
567}
568
569func newSubscription(t *topic, mu *sync.Mutex, ps *pb.Subscription) *subscription {
570	at := time.Duration(ps.AckDeadlineSeconds) * time.Second
571	if at == 0 {
572		at = 10 * time.Second
573	}
574	return &subscription{
575		topic:      t,
576		mu:         mu,
577		proto:      ps,
578		ackTimeout: at,
579		msgs:       map[string]*message{},
580		done:       make(chan struct{}),
581	}
582}
583
584func (s *subscription) start(wg *sync.WaitGroup) {
585	wg.Add(1)
586	go func() {
587		defer wg.Done()
588		for {
589			select {
590			case <-s.done:
591				return
592			case <-time.After(10 * time.Millisecond):
593				s.deliver()
594			}
595		}
596	}()
597}
598
599func (s *subscription) stop() {
600	close(s.done)
601}
602
603func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
604	s.mu.Lock()
605	defer s.mu.Unlock()
606
607	sub, err := s.findSubscription(req.Subscription)
608	if err != nil {
609		return nil, err
610	}
611	for _, id := range req.AckIds {
612		sub.ack(id)
613	}
614	return &emptypb.Empty{}, nil
615}
616
617func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
618	s.mu.Lock()
619	defer s.mu.Unlock()
620	sub, err := s.findSubscription(req.Subscription)
621	if err != nil {
622		return nil, err
623	}
624	now := time.Now()
625	for _, id := range req.AckIds {
626		s.msgsByID[id].Modacks = append(s.msgsByID[id].Modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
627	}
628	dur := secsToDur(req.AckDeadlineSeconds)
629	for _, id := range req.AckIds {
630		sub.modifyAckDeadline(id, dur)
631	}
632	return &emptypb.Empty{}, nil
633}
634
635func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
636	s.mu.Lock()
637	sub, err := s.findSubscription(req.Subscription)
638	if err != nil {
639		s.mu.Unlock()
640		return nil, err
641	}
642	max := int(req.MaxMessages)
643	if max < 0 {
644		s.mu.Unlock()
645		return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative")
646	}
647	if max == 0 { // MaxMessages not specified; use a default.
648		max = 1000
649	}
650	msgs := sub.pull(max)
651	s.mu.Unlock()
652	// Implement the spec from the pubsub proto:
653	// "If ReturnImmediately set to true, the system will respond immediately even if
654	// it there are no messages available to return in the `Pull` response.
655	// Otherwise, the system may wait (for a bounded amount of time) until at
656	// least one message is available, rather than returning no messages."
657	if len(msgs) == 0 && !req.ReturnImmediately {
658		// Wait for a short amount of time for a message.
659		// TODO: signal when a message arrives, so we don't wait the whole time.
660		select {
661		case <-ctx.Done():
662			return nil, ctx.Err()
663		case <-time.After(500 * time.Millisecond):
664			s.mu.Lock()
665			msgs = sub.pull(max)
666			s.mu.Unlock()
667		}
668	}
669	return &pb.PullResponse{ReceivedMessages: msgs}, nil
670}
671
672func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
673	// Receive initial message configuring the pull.
674	req, err := sps.Recv()
675	if err != nil {
676		return err
677	}
678	s.mu.Lock()
679	sub, err := s.findSubscription(req.Subscription)
680	s.mu.Unlock()
681	if err != nil {
682		return err
683	}
684	// Create a new stream to handle the pull.
685	st := sub.newStream(sps, s.streamTimeout)
686	err = st.pull(&s.wg)
687	sub.deleteStream(st)
688	return err
689}
690
691func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
692	// Only handle time-based seeking for now.
693	// This fake doesn't deal with snapshots.
694	var target time.Time
695	switch v := req.Target.(type) {
696	case nil:
697		return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
698	case *pb.SeekRequest_Time:
699		var err error
700		target, err = ptypes.Timestamp(v.Time)
701		if err != nil {
702			return nil, status.Errorf(codes.InvalidArgument, "bad Time target: %v", err)
703		}
704	default:
705		return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v)
706	}
707
708	// The entire server must be locked while doing the work below,
709	// because the messages don't have any other synchronization.
710	s.mu.Lock()
711	defer s.mu.Unlock()
712	sub, err := s.findSubscription(req.Subscription)
713	if err != nil {
714		return nil, err
715	}
716	// Drop all messages from sub that were published before the target time.
717	for id, m := range sub.msgs {
718		if m.publishTime.Before(target) {
719			delete(sub.msgs, id)
720			(*m.acks)++
721		}
722	}
723	// Un-ack any already-acked messages after this time;
724	// redelivering them to the subscription is the closest analogue here.
725	for _, m := range s.msgs {
726		if m.PublishTime.Before(target) {
727			continue
728		}
729		sub.msgs[m.ID] = &message{
730			publishTime: m.PublishTime,
731			proto: &pb.ReceivedMessage{
732				AckId: m.ID,
733				// This was not preserved!
734				//Message: pm,
735			},
736			deliveries:  &m.deliveries,
737			acks:        &m.acks,
738			streamIndex: -1,
739		}
740	}
741	return &pb.SeekResponse{}, nil
742}
743
744// Gets a subscription that must exist.
745// Must be called with the lock held.
746func (s *GServer) findSubscription(name string) (*subscription, error) {
747	if name == "" {
748		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
749	}
750	sub := s.subs[name]
751	if sub == nil {
752		return nil, status.Errorf(codes.NotFound, "subscription %s", name)
753	}
754	return sub, nil
755}
756
757// Must be called with the lock held.
758func (s *subscription) pull(max int) []*pb.ReceivedMessage {
759	now := timeNow()
760	s.maintainMessages(now)
761	var msgs []*pb.ReceivedMessage
762	for _, m := range s.msgs {
763		if m.outstanding() {
764			continue
765		}
766		(*m.deliveries)++
767		m.ackDeadline = now.Add(s.ackTimeout)
768		msgs = append(msgs, m.proto)
769		if len(msgs) >= max {
770			break
771		}
772	}
773	return msgs
774}
775
// deliver makes one pass over the subscription's messages. Each message that
// is not outstanding is offered to the open streaming-pull streams,
// round-robin. It locks the server mutex (s.mu) for the whole pass.
func (s *subscription) deliver() {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := timeNow()
	// Expire ack deadlines and drop old messages before delivering.
	s.maintainMessages(now)
	// Try to deliver each remaining message.
	curIndex := 0
	for _, m := range s.msgs {
		if m.outstanding() {
			continue
		}
		// If the message was never delivered before, start with the stream at
		// curIndex. If it was delivered before, start with the stream after the one
		// that owned it.
		//
		// NOTE(review): the two branches record streamIndex differently. A first
		// delivery stores delIndex+1 (the stream after the one that received the
		// message). A redelivery stores delIndex itself, so the next attempt for
		// that message starts at the same stream. The comment above only matches
		// the first-delivery case; confirm which behavior is intended.
		if m.streamIndex < 0 {
			delIndex, ok := s.tryDeliverMessage(m, curIndex, now)
			if !ok {
				// No stream accepted the message; end this pass. The delivery
				// goroutine calls deliver again later.
				break
			}
			curIndex = delIndex + 1
			m.streamIndex = curIndex
		} else {
			delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now)
			if !ok {
				// No stream accepted the message; end this pass.
				break
			}
			m.streamIndex = delIndex
		}
	}
}
807
// tryDeliverMessage attempts to deliver m to the stream at index start. If it
// can't, it tries streams start+1, start+2, ..., wrapping around. Once it's
// tried all streams, it exits.
//
// A send succeeds only if the stream's sendLoop is receiving at that moment:
// msgc is unbuffered and the send has a default case. On success the delivery
// is counted and m's ack deadline is set from the stream's ackTimeout. Streams
// found to be done are removed from s.streams as they are encountered.
//
// It returns the index of the stream it delivered the message to, or 0, false if
// it didn't deliver the message.
//
// NOTE(review): removing a finished stream shifts the indices of later streams.
// The loop then recomputes idx modulo the shorter slice, so the visiting order
// after a removal may not be exactly start, start+1, .... The returned index
// refers to s.streams after any removals.
//
// Must be called with the lock held.
func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) {
	for i := 0; i < len(s.streams); i++ {
		idx := (i + start) % len(s.streams)

		st := s.streams[idx]
		select {
		case <-st.done:
			// The stream has finished: drop it and repeat this iteration
			// against the shortened slice.
			s.streams = deleteStreamAt(s.streams, idx)
			i--

		case st.msgc <- m.proto:
			(*m.deliveries)++
			m.ackDeadline = now.Add(st.ackTimeout)
			return idx, true

		default:
			// The stream isn't ready to receive; try the next one.
		}
	}
	return 0, false
}
836
// retentionDuration bounds how long a message may remain in a subscription.
// Once a message was published more than retentionDuration ago and is not
// outstanding, maintainMessages discards it. It is a variable, presumably so
// tests can shorten it. It is independent of the subscription's
// message_retention_duration setting.
var (
	retentionDuration = 10 * time.Minute
)
838
839// Must be called with the lock held.
840func (s *subscription) maintainMessages(now time.Time) {
841	for id, m := range s.msgs {
842		// Mark a message as re-deliverable if its ack deadline has expired.
843		if m.outstanding() && now.After(m.ackDeadline) {
844			m.makeAvailable()
845		}
846		pubTime, err := ptypes.Timestamp(m.proto.Message.PublishTime)
847		if err != nil {
848			panic(err)
849		}
850		// Remove messages that have been undelivered for a long time.
851		if !m.outstanding() && now.Sub(pubTime) > retentionDuration {
852			delete(s.msgs, id)
853		}
854	}
855}
856
857func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
858	st := &stream{
859		sub:        s,
860		done:       make(chan struct{}),
861		msgc:       make(chan *pb.ReceivedMessage),
862		gstream:    gs,
863		ackTimeout: s.ackTimeout,
864		timeout:    timeout,
865	}
866	s.mu.Lock()
867	s.streams = append(s.streams, st)
868	s.mu.Unlock()
869	return st
870}
871
872func (s *subscription) deleteStream(st *stream) {
873	s.mu.Lock()
874	defer s.mu.Unlock()
875	var i int
876	for i = 0; i < len(s.streams); i++ {
877		if s.streams[i] == st {
878			break
879		}
880	}
881	if i < len(s.streams) {
882		s.streams = deleteStreamAt(s.streams, i)
883	}
884}
885func deleteStreamAt(s []*stream, i int) []*stream {
886	// Preserve order for round-robin delivery.
887	return append(s[:i], s[i+1:]...)
888}
889
890type message struct {
891	proto       *pb.ReceivedMessage
892	publishTime time.Time
893	ackDeadline time.Time
894	deliveries  *int
895	acks        *int
896	streamIndex int // index of stream that currently owns msg, for round-robin delivery
897}
898
899// A message is outstanding if it is owned by some stream.
900func (m *message) outstanding() bool {
901	return !m.ackDeadline.IsZero()
902}
903
904func (m *message) makeAvailable() {
905	m.ackDeadline = time.Time{}
906}
907
908type stream struct {
909	sub        *subscription
910	done       chan struct{} // closed when the stream is finished
911	msgc       chan *pb.ReceivedMessage
912	gstream    pb.Subscriber_StreamingPullServer
913	ackTimeout time.Duration
914	timeout    time.Duration
915}
916
917// pull manages the StreamingPull interaction for the life of the stream.
918func (st *stream) pull(wg *sync.WaitGroup) error {
919	errc := make(chan error, 2)
920	wg.Add(2)
921	go func() {
922		defer wg.Done()
923		errc <- st.sendLoop()
924	}()
925	go func() {
926		defer wg.Done()
927		errc <- st.recvLoop()
928	}()
929	var tchan <-chan time.Time
930	if st.timeout > 0 {
931		tchan = time.After(st.timeout)
932	}
933	// Wait until one of the goroutines returns an error, or we time out.
934	var err error
935	select {
936	case err = <-errc:
937		if err == io.EOF {
938			err = nil
939		}
940	case <-tchan:
941	}
942	close(st.done) // stop the other goroutine
943	return err
944}
945
946func (st *stream) sendLoop() error {
947	for {
948		select {
949		case <-st.done:
950			return nil
951		case rm := <-st.msgc:
952			res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}}
953			if err := st.gstream.Send(res); err != nil {
954				return err
955			}
956		}
957	}
958}
959
960func (st *stream) recvLoop() error {
961	for {
962		req, err := st.gstream.Recv()
963		if err != nil {
964			return err
965		}
966		st.sub.handleStreamingPullRequest(st, req)
967	}
968}
969
970func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
971	// Lock the entire server.
972	s.mu.Lock()
973	defer s.mu.Unlock()
974
975	for _, ackID := range req.AckIds {
976		s.ack(ackID)
977	}
978	for i, id := range req.ModifyDeadlineAckIds {
979		s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
980	}
981	if req.StreamAckDeadlineSeconds > 0 {
982		st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
983	}
984}
985
986// Must be called with the lock held.
987func (s *subscription) ack(id string) {
988	m := s.msgs[id]
989	if m != nil {
990		(*m.acks)++
991		delete(s.msgs, id)
992	}
993}
994
995// Must be called with the lock held.
996func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
997	m := s.msgs[id]
998	if m == nil { // already acked: ignore.
999		return
1000	}
1001	if d == 0 { // nack
1002		m.makeAvailable()
1003	} else { // extend the deadline by d
1004		m.ackDeadline = timeNow().Add(d)
1005	}
1006}
1007
// secsToDur converts a whole number of seconds to a time.Duration.
func secsToDur(secs int32) time.Duration {
	return time.Second * time.Duration(secs)
}
1011