1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package pstest provides a fake Cloud PubSub service for testing. It implements a
16// simplified form of the service, suitable for unit tests. It may behave
17// differently from the actual service in ways in which the service is
18// non-deterministic or unspecified: timing, delivery order, etc.
19//
20// This package is EXPERIMENTAL and is subject to change without notice.
21//
22// See the example for usage.
23package pstest
24
25import (
26	"context"
27	"fmt"
28	"io"
29	"path"
30	"sort"
31	"strings"
32	"sync"
33	"sync/atomic"
34	"time"
35
36	"cloud.google.com/go/internal/testutil"
37	"github.com/golang/protobuf/ptypes"
38	durpb "github.com/golang/protobuf/ptypes/duration"
39	emptypb "github.com/golang/protobuf/ptypes/empty"
40	pb "google.golang.org/genproto/googleapis/pubsub/v1"
41	"google.golang.org/grpc/codes"
42	"google.golang.org/grpc/status"
43)
44
// now holds the clock used throughout the fake, stored as a func() time.Time
// and read via timeNow. init sets it to time.Now; it exists so that tests can
// substitute their own clock.
//
// Note that even though changes to the now variable are atomic, a call to the
// stored function can race with a change to that function. This could be a
// problem if tests are run in parallel, or even if concurrent parts of the same
// test change the value of the variable.
var now atomic.Value
50
51func init() {
52	now.Store(time.Now)
53	ResetMinAckDeadline()
54}
55
56func timeNow() time.Time {
57	return now.Load().(func() time.Time)()
58}
59
// Server is a fake Pub/Sub server. Create one with NewServer and shut it down
// with Close.
type Server struct {
	srv     *testutil.Server
	Addr    string  // The address that the server is listening on.
	GServer GServer // Not intended to be used directly.
}
66
// GServer is the underlying service implementor. It is not intended to be used
// directly.
//
// mu guards the topic, subscription and message state below. Every
// subscription holds a pointer to this same mutex.
type GServer struct {
	pb.PublisherServer
	pb.SubscriberServer

	mu            sync.Mutex
	topics        map[string]*topic        // by full topic name
	subs          map[string]*subscription // by full subscription name
	msgs          []*Message               // all messages ever published
	msgsByID      map[string]*Message      // the same messages, keyed by message ID
	wg            sync.WaitGroup           // background goroutines; see Server.Wait
	nextID        int                      // counter used to generate message IDs
	streamTimeout time.Duration            // see Server.SetStreamTimeout
}
82
83// NewServer creates a new fake server running in the current process.
84func NewServer() *Server {
85	srv, err := testutil.NewServer()
86	if err != nil {
87		panic(fmt.Sprintf("pstest.NewServer: %v", err))
88	}
89	s := &Server{
90		srv:  srv,
91		Addr: srv.Addr,
92		GServer: GServer{
93			topics:   map[string]*topic{},
94			subs:     map[string]*subscription{},
95			msgsByID: map[string]*Message{},
96		},
97	}
98	pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
99	pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
100	srv.Start()
101	return s
102}
103
104// Publish behaves as if the Publish RPC was called with a message with the given
105// data and attrs. It returns the ID of the message.
106// The topic will be created if it doesn't exist.
107//
108// Publish panics if there is an error, which is appropriate for testing.
109func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
110	const topicPattern = "projects/*/topics/*"
111	ok, err := path.Match(topicPattern, topic)
112	if err != nil {
113		panic(err)
114	}
115	if !ok {
116		panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
117	}
118	_, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
119	req := &pb.PublishRequest{
120		Topic:    topic,
121		Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs}},
122	}
123	res, err := s.GServer.Publish(context.TODO(), req)
124	if err != nil {
125		panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
126	}
127	return res.MessageIds[0]
128}
129
130// SetStreamTimeout sets the amount of time a stream will be active before it shuts
131// itself down. This mimics the real service's behavior of closing streams after 30
132// minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
133// down.
134func (s *Server) SetStreamTimeout(d time.Duration) {
135	s.GServer.mu.Lock()
136	defer s.GServer.mu.Unlock()
137	s.GServer.streamTimeout = d
138}
139
// A Message is a message that was published to the server.
//
// Deliveries and Acks are snapshots: they are refreshed from the internal
// counters whenever the message is returned by Server.Messages or
// Server.Message.
type Message struct {
	ID          string
	Data        []byte
	Attributes  map[string]string
	PublishTime time.Time
	Deliveries  int // number of times delivery of the message was attempted
	Acks        int // number of acks received from clients

	// protected by server mutex
	deliveries int
	acks       int
	Modacks    []Modack // modacks received via the ModifyAckDeadline RPC for this message

}
155
// Modack represents a modack sent to the server through the ModifyAckDeadline
// RPC.
type Modack struct {
	AckID       string    // the ack ID named in the request
	AckDeadline int32     // the requested deadline, in seconds
	ReceivedAt  time.Time // when the server handled the request
}
162
163// Messages returns information about all messages ever published.
164func (s *Server) Messages() []*Message {
165	s.GServer.mu.Lock()
166	defer s.GServer.mu.Unlock()
167
168	var msgs []*Message
169	for _, m := range s.GServer.msgs {
170		m.Deliveries = m.deliveries
171		m.Acks = m.acks
172		msgs = append(msgs, m)
173	}
174	return msgs
175}
176
177// Message returns the message with the given ID, or nil if no message
178// with that ID was published.
179func (s *Server) Message(id string) *Message {
180	s.GServer.mu.Lock()
181	defer s.GServer.mu.Unlock()
182
183	m := s.GServer.msgsByID[id]
184	if m != nil {
185		m.Deliveries = m.deliveries
186		m.Acks = m.acks
187	}
188	return m
189}
190
191// Wait blocks until all server activity has completed.
192func (s *Server) Wait() {
193	s.GServer.wg.Wait()
194}
195
196// Close shuts down the server and releases all resources.
197func (s *Server) Close() error {
198	s.srv.Close()
199	s.GServer.mu.Lock()
200	defer s.GServer.mu.Unlock()
201	for _, sub := range s.GServer.subs {
202		sub.stop()
203	}
204	return nil
205}
206
207func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
208	s.mu.Lock()
209	defer s.mu.Unlock()
210
211	if s.topics[t.Name] != nil {
212		return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
213	}
214	top := newTopic(t)
215	s.topics[t.Name] = top
216	return top.proto, nil
217}
218
219func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
220	s.mu.Lock()
221	defer s.mu.Unlock()
222
223	if t := s.topics[req.Topic]; t != nil {
224		return t.proto, nil
225	}
226	return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
227}
228
229func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
230	s.mu.Lock()
231	defer s.mu.Unlock()
232
233	t := s.topics[req.Topic.Name]
234	if t == nil {
235		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
236	}
237	for _, path := range req.UpdateMask.Paths {
238		switch path {
239		case "labels":
240			t.proto.Labels = req.Topic.Labels
241		case "message_storage_policy": // "fetch" the policy
242			t.proto.MessageStoragePolicy = &pb.MessageStoragePolicy{AllowedPersistenceRegions: []string{"US"}}
243		default:
244			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
245		}
246	}
247	return t.proto, nil
248}
249
250func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
251	s.mu.Lock()
252	defer s.mu.Unlock()
253
254	var names []string
255	for n := range s.topics {
256		if strings.HasPrefix(n, req.Project) {
257			names = append(names, n)
258		}
259	}
260	sort.Strings(names)
261	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
262	if err != nil {
263		return nil, err
264	}
265	res := &pb.ListTopicsResponse{NextPageToken: nextToken}
266	for i := from; i < to; i++ {
267		res.Topics = append(res.Topics, s.topics[names[i]].proto)
268	}
269	return res, nil
270}
271
272func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
273	s.mu.Lock()
274	defer s.mu.Unlock()
275
276	var names []string
277	for name, sub := range s.subs {
278		if sub.topic.proto.Name == req.Topic {
279			names = append(names, name)
280		}
281	}
282	sort.Strings(names)
283	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
284	if err != nil {
285		return nil, err
286	}
287	return &pb.ListTopicSubscriptionsResponse{
288		Subscriptions: names[from:to],
289		NextPageToken: nextToken,
290	}, nil
291}
292
293func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
294	s.mu.Lock()
295	defer s.mu.Unlock()
296
297	t := s.topics[req.Topic]
298	if t == nil {
299		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
300	}
301	t.stop()
302	delete(s.topics, req.Topic)
303	return &emptypb.Empty{}, nil
304}
305
306func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
307	s.mu.Lock()
308	defer s.mu.Unlock()
309
310	if ps.Name == "" {
311		return nil, status.Errorf(codes.InvalidArgument, "missing name")
312	}
313	if s.subs[ps.Name] != nil {
314		return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
315	}
316	if ps.Topic == "" {
317		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
318	}
319	top := s.topics[ps.Topic]
320	if top == nil {
321		return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
322	}
323	if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
324		return nil, err
325	}
326	if ps.MessageRetentionDuration == nil {
327		ps.MessageRetentionDuration = defaultMessageRetentionDuration
328	}
329	if err := checkMRD(ps.MessageRetentionDuration); err != nil {
330		return nil, err
331	}
332	if ps.PushConfig == nil {
333		ps.PushConfig = &pb.PushConfig{}
334	}
335
336	sub := newSubscription(top, &s.mu, ps)
337	top.subs[ps.Name] = sub
338	s.subs[ps.Name] = sub
339	sub.start(&s.wg)
340	return ps, nil
341}
342
// minAckDeadlineSecs is the smallest ack deadline, in seconds, that
// checkAckDeadline accepts. It can be set for testing with SetMinAckDeadline
// and is restored by ResetMinAckDeadline, which init also calls.
var minAckDeadlineSecs int32
345
346// SetMinAckDeadline changes the minack deadline to n. Must be
347// greater than or equal to 1 second. Remember to reset this value
348// to the default after your test changes it. Example usage:
349// 		pstest.SetMinAckDeadlineSecs(1)
350// 		defer pstest.ResetMinAckDeadlineSecs()
351func SetMinAckDeadline(n time.Duration) {
352	if n < time.Second {
353		panic("SetMinAckDeadline expects a value greater than 1 second")
354	}
355
356	minAckDeadlineSecs = int32(n / time.Second)
357}
358
359// ResetMinAckDeadline resets the minack deadline to the default.
360func ResetMinAckDeadline() {
361	minAckDeadlineSecs = 10
362}
363
364func checkAckDeadline(ads int32) error {
365	if ads < minAckDeadlineSecs || ads > 600 {
366		// PubSub service returns Unknown.
367		return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
368	}
369	return nil
370}
371
// Bounds enforced by checkMRD on a subscription's message_retention_duration.
const (
	minMessageRetentionDuration = 10 * time.Minute
	maxMessageRetentionDuration = 168 * time.Hour // 7 days
)

// defaultMessageRetentionDuration is stored by CreateSubscription when the
// request leaves message_retention_duration unset. It equals the maximum
// allowed retention.
var defaultMessageRetentionDuration = ptypes.DurationProto(maxMessageRetentionDuration)
378
379func checkMRD(pmrd *durpb.Duration) error {
380	mrd, err := ptypes.Duration(pmrd)
381	if err != nil || mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
382		return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
383	}
384	return nil
385}
386
387func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
388	s.mu.Lock()
389	defer s.mu.Unlock()
390	sub, err := s.findSubscription(req.Subscription)
391	if err != nil {
392		return nil, err
393	}
394	return sub.proto, nil
395}
396
397func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
398	if req.Subscription == nil {
399		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
400	}
401	s.mu.Lock()
402	defer s.mu.Unlock()
403	sub, err := s.findSubscription(req.Subscription.Name)
404	if err != nil {
405		return nil, err
406	}
407	for _, path := range req.UpdateMask.Paths {
408		switch path {
409		case "push_config":
410			sub.proto.PushConfig = req.Subscription.PushConfig
411
412		case "ack_deadline_seconds":
413			a := req.Subscription.AckDeadlineSeconds
414			if err := checkAckDeadline(a); err != nil {
415				return nil, err
416			}
417			sub.proto.AckDeadlineSeconds = a
418
419		case "retain_acked_messages":
420			sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
421
422		case "message_retention_duration":
423			if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
424				return nil, err
425			}
426			sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
427
428		case "labels":
429			sub.proto.Labels = req.Subscription.Labels
430
431		case "expiration_policy":
432			sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy
433
434		default:
435			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
436		}
437	}
438	return sub.proto, nil
439}
440
441func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
442	s.mu.Lock()
443	defer s.mu.Unlock()
444
445	var names []string
446	for name := range s.subs {
447		if strings.HasPrefix(name, req.Project) {
448			names = append(names, name)
449		}
450	}
451	sort.Strings(names)
452	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
453	if err != nil {
454		return nil, err
455	}
456	res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
457	for i := from; i < to; i++ {
458		res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
459	}
460	return res, nil
461}
462
463func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
464	s.mu.Lock()
465	defer s.mu.Unlock()
466	sub, err := s.findSubscription(req.Subscription)
467	if err != nil {
468		return nil, err
469	}
470	sub.stop()
471	delete(s.subs, req.Subscription)
472	sub.topic.deleteSub(sub)
473	return &emptypb.Empty{}, nil
474}
475
476func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
477	s.mu.Lock()
478	defer s.mu.Unlock()
479
480	if req.Topic == "" {
481		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
482	}
483	top := s.topics[req.Topic]
484	if top == nil {
485		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
486	}
487	var ids []string
488	for _, pm := range req.Messages {
489		id := fmt.Sprintf("m%d", s.nextID)
490		s.nextID++
491		pm.MessageId = id
492		pubTime := timeNow()
493		tsPubTime, err := ptypes.TimestampProto(pubTime)
494		if err != nil {
495			return nil, status.Errorf(codes.Internal, err.Error())
496		}
497		pm.PublishTime = tsPubTime
498		m := &Message{
499			ID:          id,
500			Data:        pm.Data,
501			Attributes:  pm.Attributes,
502			PublishTime: pubTime,
503		}
504		top.publish(pm, m)
505		ids = append(ids, id)
506		s.msgs = append(s.msgs, m)
507		s.msgsByID[id] = m
508	}
509	return &pb.PublishResponse{MessageIds: ids}, nil
510}
511
// topic is the server-side state of a Pub/Sub topic.
type topic struct {
	proto *pb.Topic
	subs  map[string]*subscription // attached subscriptions, by subscription name
}
516
517func newTopic(pt *pb.Topic) *topic {
518	return &topic{
519		proto: pt,
520		subs:  map[string]*subscription{},
521	}
522}
523
524func (t *topic) stop() {
525	for _, sub := range t.subs {
526		sub.proto.Topic = "_deleted-topic_"
527		sub.stop()
528	}
529}
530
531func (t *topic) deleteSub(sub *subscription) {
532	delete(t.subs, sub.proto.Name)
533}
534
535func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
536	for _, s := range t.subs {
537		s.msgs[pm.MessageId] = &message{
538			publishTime: m.PublishTime,
539			proto: &pb.ReceivedMessage{
540				AckId:   pm.MessageId,
541				Message: pm,
542			},
543			deliveries:  &m.deliveries,
544			acks:        &m.acks,
545			streamIndex: -1,
546		}
547	}
548}
549
// subscription is the server-side state of a Pub/Sub subscription. Its
// message and stream state is accessed under the server mutex (mu).
type subscription struct {
	topic      *topic
	mu         *sync.Mutex // the server mutex, here for convenience
	proto      *pb.Subscription
	ackTimeout time.Duration       // deadline used by Pull; initial deadline of new streams
	msgs       map[string]*message // unacked messages by message ID
	streams    []*stream           // open StreamingPull streams, in round-robin order
	done       chan struct{}       // closed by stop to end the delivery goroutine
}
559
560func newSubscription(t *topic, mu *sync.Mutex, ps *pb.Subscription) *subscription {
561	at := time.Duration(ps.AckDeadlineSeconds) * time.Second
562	if at == 0 {
563		at = 10 * time.Second
564	}
565	return &subscription{
566		topic:      t,
567		mu:         mu,
568		proto:      ps,
569		ackTimeout: at,
570		msgs:       map[string]*message{},
571		done:       make(chan struct{}),
572	}
573}
574
575func (s *subscription) start(wg *sync.WaitGroup) {
576	wg.Add(1)
577	go func() {
578		defer wg.Done()
579		for {
580			select {
581			case <-s.done:
582				return
583			case <-time.After(10 * time.Millisecond):
584				s.deliver()
585			}
586		}
587	}()
588}
589
590func (s *subscription) stop() {
591	close(s.done)
592}
593
594func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
595	s.mu.Lock()
596	defer s.mu.Unlock()
597
598	sub, err := s.findSubscription(req.Subscription)
599	if err != nil {
600		return nil, err
601	}
602	for _, id := range req.AckIds {
603		sub.ack(id)
604	}
605	return &emptypb.Empty{}, nil
606}
607
608func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
609	s.mu.Lock()
610	defer s.mu.Unlock()
611	sub, err := s.findSubscription(req.Subscription)
612	if err != nil {
613		return nil, err
614	}
615	now := time.Now()
616	for _, id := range req.AckIds {
617		s.msgsByID[id].Modacks = append(s.msgsByID[id].Modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
618	}
619	dur := secsToDur(req.AckDeadlineSeconds)
620	for _, id := range req.AckIds {
621		sub.modifyAckDeadline(id, dur)
622	}
623	return &emptypb.Empty{}, nil
624}
625
626func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
627	s.mu.Lock()
628	sub, err := s.findSubscription(req.Subscription)
629	if err != nil {
630		s.mu.Unlock()
631		return nil, err
632	}
633	max := int(req.MaxMessages)
634	if max < 0 {
635		s.mu.Unlock()
636		return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative")
637	}
638	if max == 0 { // MaxMessages not specified; use a default.
639		max = 1000
640	}
641	msgs := sub.pull(max)
642	s.mu.Unlock()
643	// Implement the spec from the pubsub proto:
644	// "If ReturnImmediately set to true, the system will respond immediately even if
645	// it there are no messages available to return in the `Pull` response.
646	// Otherwise, the system may wait (for a bounded amount of time) until at
647	// least one message is available, rather than returning no messages."
648	if len(msgs) == 0 && !req.ReturnImmediately {
649		// Wait for a short amount of time for a message.
650		// TODO: signal when a message arrives, so we don't wait the whole time.
651		select {
652		case <-ctx.Done():
653			return nil, ctx.Err()
654		case <-time.After(500 * time.Millisecond):
655			s.mu.Lock()
656			msgs = sub.pull(max)
657			s.mu.Unlock()
658		}
659	}
660	return &pb.PullResponse{ReceivedMessages: msgs}, nil
661}
662
663func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
664	// Receive initial message configuring the pull.
665	req, err := sps.Recv()
666	if err != nil {
667		return err
668	}
669	s.mu.Lock()
670	sub, err := s.findSubscription(req.Subscription)
671	s.mu.Unlock()
672	if err != nil {
673		return err
674	}
675	// Create a new stream to handle the pull.
676	st := sub.newStream(sps, s.streamTimeout)
677	err = st.pull(&s.wg)
678	sub.deleteStream(st)
679	return err
680}
681
682func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
683	// Only handle time-based seeking for now.
684	// This fake doesn't deal with snapshots.
685	var target time.Time
686	switch v := req.Target.(type) {
687	case nil:
688		return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
689	case *pb.SeekRequest_Time:
690		var err error
691		target, err = ptypes.Timestamp(v.Time)
692		if err != nil {
693			return nil, status.Errorf(codes.InvalidArgument, "bad Time target: %v", err)
694		}
695	default:
696		return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v)
697	}
698
699	// The entire server must be locked while doing the work below,
700	// because the messages don't have any other synchronization.
701	s.mu.Lock()
702	defer s.mu.Unlock()
703	sub, err := s.findSubscription(req.Subscription)
704	if err != nil {
705		return nil, err
706	}
707	// Drop all messages from sub that were published before the target time.
708	for id, m := range sub.msgs {
709		if m.publishTime.Before(target) {
710			delete(sub.msgs, id)
711			(*m.acks)++
712		}
713	}
714	// Un-ack any already-acked messages after this time;
715	// redelivering them to the subscription is the closest analogue here.
716	for _, m := range s.msgs {
717		if m.PublishTime.Before(target) {
718			continue
719		}
720		sub.msgs[m.ID] = &message{
721			publishTime: m.PublishTime,
722			proto: &pb.ReceivedMessage{
723				AckId: m.ID,
724				// This was not preserved!
725				//Message: pm,
726			},
727			deliveries:  &m.deliveries,
728			acks:        &m.acks,
729			streamIndex: -1,
730		}
731	}
732	return &pb.SeekResponse{}, nil
733}
734
735// Gets a subscription that must exist.
736// Must be called with the lock held.
737func (s *GServer) findSubscription(name string) (*subscription, error) {
738	if name == "" {
739		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
740	}
741	sub := s.subs[name]
742	if sub == nil {
743		return nil, status.Errorf(codes.NotFound, "subscription %s", name)
744	}
745	return sub, nil
746}
747
748// Must be called with the lock held.
749func (s *subscription) pull(max int) []*pb.ReceivedMessage {
750	now := timeNow()
751	s.maintainMessages(now)
752	var msgs []*pb.ReceivedMessage
753	for _, m := range s.msgs {
754		if m.outstanding() {
755			continue
756		}
757		(*m.deliveries)++
758		m.ackDeadline = now.Add(s.ackTimeout)
759		msgs = append(msgs, m.proto)
760		if len(msgs) >= max {
761			break
762		}
763	}
764	return msgs
765}
766
// deliver pushes available messages to the subscription's open streams. It is
// called periodically by the goroutine launched in start and takes the server
// lock itself.
//
// A message never delivered before (streamIndex < 0) starts its search at
// curIndex. curIndex advances past each stream that receives such a message,
// so new messages are spread round-robin. A message that was delivered before
// starts its search at its recorded streamIndex. The loop stops at the first
// message that no stream accepts right now; the rest wait for the next call.
func (s *subscription) deliver() {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := timeNow()
	s.maintainMessages(now)
	// Try to deliver each remaining message.
	curIndex := 0
	for _, m := range s.msgs {
		if m.outstanding() {
			continue
		}
		// If the message was never delivered before, start with the stream at
		// curIndex. If it was delivered before, start with the stream after the one
		// that owned it.
		if m.streamIndex < 0 {
			delIndex, ok := s.tryDeliverMessage(m, curIndex, now)
			if !ok {
				break
			}
			curIndex = delIndex + 1
			// NOTE(review): this records the stream *after* the one that
			// received m, whereas the redelivery branch below records the
			// receiving stream itself. Confirm which is intended.
			m.streamIndex = curIndex
		} else {
			delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now)
			if !ok {
				break
			}
			m.streamIndex = delIndex
		}
	}
}
798
// tryDeliverMessage attempts to deliver m to the stream at index start. If it
// can't, it tries streams start+1, start+2, ..., wrapping around. Once it's
// tried all streams, it exits.
//
// Each send is non-blocking: a stream receives m only if its sendLoop is
// waiting on the unbuffered msgc at that moment. Streams whose done channel is
// closed are removed from s.streams along the way. On success, m's shared
// delivery counter is incremented and its ack deadline is set from that
// stream's ackTimeout.
//
// It returns the index of the stream it delivered the message to, or 0, false if
// it didn't deliver the message (including when there are no streams).
//
// NOTE(review): after a removal the modulus shrinks. Once the scan has wrapped
// around, it may revisit one stream and skip another during this call.
//
// Must be called with the lock held.
func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) {
	for i := 0; i < len(s.streams); i++ {
		idx := (i + start) % len(s.streams)

		st := s.streams[idx]
		select {
		case <-st.done:
			// The stream has finished; drop it and retry this position.
			s.streams = deleteStreamAt(s.streams, idx)
			i--

		case st.msgc <- m.proto:
			(*m.deliveries)++
			m.ackDeadline = now.Add(st.ackTimeout)
			return idx, true

		default:
			// The stream is not ready to receive; try the next one.
		}
	}
	return 0, false
}
827
// retentionDuration is how long after its publish time maintainMessages keeps
// a message that is not outstanding. It is independent of the subscription's
// message_retention_duration setting.
var retentionDuration = 10 * time.Minute
829
830// Must be called with the lock held.
831func (s *subscription) maintainMessages(now time.Time) {
832	for id, m := range s.msgs {
833		// Mark a message as re-deliverable if its ack deadline has expired.
834		if m.outstanding() && now.After(m.ackDeadline) {
835			m.makeAvailable()
836		}
837		pubTime, err := ptypes.Timestamp(m.proto.Message.PublishTime)
838		if err != nil {
839			panic(err)
840		}
841		// Remove messages that have been undelivered for a long time.
842		if !m.outstanding() && now.Sub(pubTime) > retentionDuration {
843			delete(s.msgs, id)
844		}
845	}
846}
847
848func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
849	st := &stream{
850		sub:        s,
851		done:       make(chan struct{}),
852		msgc:       make(chan *pb.ReceivedMessage),
853		gstream:    gs,
854		ackTimeout: s.ackTimeout,
855		timeout:    timeout,
856	}
857	s.mu.Lock()
858	s.streams = append(s.streams, st)
859	s.mu.Unlock()
860	return st
861}
862
863func (s *subscription) deleteStream(st *stream) {
864	s.mu.Lock()
865	defer s.mu.Unlock()
866	var i int
867	for i = 0; i < len(s.streams); i++ {
868		if s.streams[i] == st {
869			break
870		}
871	}
872	if i < len(s.streams) {
873		s.streams = deleteStreamAt(s.streams, i)
874	}
875}
876func deleteStreamAt(s []*stream, i int) []*stream {
877	// Preserve order for round-robin delivery.
878	return append(s[:i], s[i+1:]...)
879}
880
// message is a subscription's record of one unacked message.
type message struct {
	proto       *pb.ReceivedMessage
	publishTime time.Time
	ackDeadline time.Time // zero when the message is not outstanding
	deliveries  *int      // points into the published Message's counter
	acks        *int      // points into the published Message's counter
	streamIndex int       // index of stream that currently owns msg, for round-robin delivery; -1 if never delivered to a stream
}
889
890// A message is outstanding if it is owned by some stream.
891func (m *message) outstanding() bool {
892	return !m.ackDeadline.IsZero()
893}
894
895func (m *message) makeAvailable() {
896	m.ackDeadline = time.Time{}
897}
898
// stream is the server side of one StreamingPull call.
type stream struct {
	sub        *subscription
	done       chan struct{}            // closed when the stream is finished
	msgc       chan *pb.ReceivedMessage // unbuffered; tryDeliverMessage hands messages to sendLoop here
	gstream    pb.Subscriber_StreamingPullServer
	ackTimeout time.Duration // deadline given to messages delivered on this stream
	timeout    time.Duration // lifetime of the stream; values <= 0 mean no limit
}
907
// pull manages the StreamingPull interaction for the life of the stream.
//
// It runs sendLoop and recvLoop in goroutines registered with wg. It then
// waits until either goroutine returns or, if st.timeout is positive, until
// the timeout elapses. After that it closes st.done, which ends sendLoop.
// recvLoop returns only once Recv fails (presumably when the RPC ends).
//
// The first goroutine result is returned. io.EOF (the client closed its side)
// and a timeout are both reported as nil.
func (st *stream) pull(wg *sync.WaitGroup) error {
	// Buffered for both goroutines, so neither blocks sending its result
	// after pull has stopped listening.
	errc := make(chan error, 2)
	wg.Add(2)
	go func() {
		defer wg.Done()
		errc <- st.sendLoop()
	}()
	go func() {
		defer wg.Done()
		errc <- st.recvLoop()
	}()
	var tchan <-chan time.Time
	if st.timeout > 0 {
		tchan = time.After(st.timeout)
	}
	// Wait until one of the goroutines returns an error, or we time out.
	var err error
	select {
	case err = <-errc:
		if err == io.EOF {
			err = nil
		}
	case <-tchan:
	}
	close(st.done) // stop the other goroutine
	return err
}
936
937func (st *stream) sendLoop() error {
938	for {
939		select {
940		case <-st.done:
941			return nil
942		case rm := <-st.msgc:
943			res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}}
944			if err := st.gstream.Send(res); err != nil {
945				return err
946			}
947		}
948	}
949}
950
951func (st *stream) recvLoop() error {
952	for {
953		req, err := st.gstream.Recv()
954		if err != nil {
955			return err
956		}
957		st.sub.handleStreamingPullRequest(st, req)
958	}
959}
960
961func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
962	// Lock the entire server.
963	s.mu.Lock()
964	defer s.mu.Unlock()
965
966	for _, ackID := range req.AckIds {
967		s.ack(ackID)
968	}
969	for i, id := range req.ModifyDeadlineAckIds {
970		s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
971	}
972	if req.StreamAckDeadlineSeconds > 0 {
973		st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
974	}
975}
976
977// Must be called with the lock held.
978func (s *subscription) ack(id string) {
979	m := s.msgs[id]
980	if m != nil {
981		(*m.acks)++
982		delete(s.msgs, id)
983	}
984}
985
986// Must be called with the lock held.
987func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
988	m := s.msgs[id]
989	if m == nil { // already acked: ignore.
990		return
991	}
992	if d == 0 { // nack
993		m.makeAvailable()
994	} else { // extend the deadline by d
995		m.ackDeadline = timeNow().Add(d)
996	}
997}
998
// secsToDur converts a number of seconds, as carried in Pub/Sub protos, to a
// time.Duration.
func secsToDur(secs int32) time.Duration {
	return time.Second * time.Duration(secs)
}
1002