1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package pstest provides a fake Cloud PubSub service for testing. It implements a
16// simplified form of the service, suitable for unit tests. It may behave
17// differently from the actual service in ways in which the service is
18// non-deterministic or unspecified: timing, delivery order, etc.
19//
20// This package is EXPERIMENTAL and is subject to change without notice.
21//
22// See the example for usage.
23package pstest
24
25import (
26	"context"
27	"fmt"
28	"io"
29	"path"
30	"sort"
31	"strings"
32	"sync"
33	"sync/atomic"
34	"time"
35
36	"cloud.google.com/go/internal/testutil"
37	pb "google.golang.org/genproto/googleapis/pubsub/v1"
38	"google.golang.org/grpc/codes"
39	"google.golang.org/grpc/status"
40	durpb "google.golang.org/protobuf/types/known/durationpb"
41	"google.golang.org/protobuf/types/known/emptypb"
42	"google.golang.org/protobuf/types/known/timestamppb"
43)
44
// ReactorOptions is a map that Server uses to look up reactors.
// Key is the function name, value is array of reactor for the function.
//
// The function name is the RPC method name passed to runReactor (for example
// "Publish" or "CreateTopic"); reactors for the same name are stored in the
// order in which they were supplied to NewServer or NewServerWithPort.
type ReactorOptions map[string][]Reactor

// Reactor is an interface to allow reaction function to a certain call.
//
// The GServer RPC handlers consult reactors while holding the server mutex,
// before doing their own work. NOTE(review): the exact argument passed to
// React is decided by runReactor (not shown here); presumably it is the RPC
// request message — confirm.
type Reactor interface {
	// React handles the message types and returns results.  If "handled" is false,
	// then the test server will ignore the results and continue to the next reactor
	// or the original handler.
	React(_ interface{}) (handled bool, ret interface{}, err error)
}

// ServerReactorOption is options passed to the server for reactor creation.
// NewServerWithPort registers Reactor under FuncName.
type ServerReactorOption struct {
	FuncName string
	Reactor  Reactor
}

// publishResponse is one user-supplied outcome for a Publish call, used when
// automatic publish responses are disabled. AddPublishResponse sets exactly
// one of resp and err.
type publishResponse struct {
	resp *pb.PublishResponse
	err  error
}
67
// now holds the func() time.Time called by timeNow. init stores time.Now in
// it, and NewServerWithPort installs timeNow as each new server's clock.
//
// For testing. Note that even though changes to the now variable are atomic, a call
// to the stored function can race with a change to that function. This could be a
// problem if tests are run in parallel, or even if concurrent parts of the same test
// change the value of the variable.
var now atomic.Value
73
74func init() {
75	now.Store(time.Now)
76	ResetMinAckDeadline()
77}
78
79func timeNow() time.Time {
80	return now.Load().(func() time.Time)()
81}
82
// Server is a fake Pub/Sub server.
//
// It owns a testutil.Server on which GServer is registered as the Publisher,
// Subscriber and SchemaService implementation (see NewServerWithPort).
type Server struct {
	srv     *testutil.Server
	Addr    string  // The address that the server is listening on.
	GServer GServer // Not intended to be used directly.
}
89
// GServer is the underlying service implementor. It is not intended to be used
// directly.
//
// The RPC handlers take mu around their work and run reactors while holding
// it. Some fields are read without mu in places (for example streamTimeout in
// StreamingPull), so see the individual methods before relying on a field
// being guarded.
type GServer struct {
	// The embedded service interfaces are never assigned (they stay nil), so
	// calling an RPC that GServer does not itself implement would panic.
	pb.PublisherServer
	pb.SubscriberServer

	mu sync.Mutex
	// topics and subs are keyed by full resource name.
	topics map[string]*topic
	subs   map[string]*subscription
	msgs   []*Message // all messages ever published (until ClearMessages)
	// msgsByID indexes msgs by message ID.
	msgsByID map[string]*Message
	// wg tracks subscription delivery goroutines and streaming pulls; see Wait.
	wg sync.WaitGroup
	// nextID is the counter used to build message IDs of the form "m<N>".
	nextID        int
	streamTimeout time.Duration
	// timeNowFunc is the server clock used for publish times.
	timeNowFunc    func() time.Time
	reactorOptions ReactorOptions
	schemas        map[string]*pb.Schema

	// publishResponses is a channel of responses to use for Publish.
	publishResponses chan *publishResponse
	// autoPublishResponse enables the server to automatically generate
	// PublishResponse when publish is called. Otherwise, responses
	// are generated from the publishResponses channel.
	autoPublishResponse bool
}
115
116// NewServer creates a new fake server running in the current process.
117func NewServer(opts ...ServerReactorOption) *Server {
118	return NewServerWithPort(0, opts...)
119}
120
121// NewServerWithPort creates a new fake server running in the current process at the specified port.
122func NewServerWithPort(port int, opts ...ServerReactorOption) *Server {
123	srv, err := testutil.NewServerWithPort(port)
124	if err != nil {
125		panic(fmt.Sprintf("pstest.NewServerWithPort: %v", err))
126	}
127	reactorOptions := ReactorOptions{}
128	for _, opt := range opts {
129		reactorOptions[opt.FuncName] = append(reactorOptions[opt.FuncName], opt.Reactor)
130	}
131	s := &Server{
132		srv:  srv,
133		Addr: srv.Addr,
134		GServer: GServer{
135			topics:              map[string]*topic{},
136			subs:                map[string]*subscription{},
137			msgsByID:            map[string]*Message{},
138			timeNowFunc:         timeNow,
139			reactorOptions:      reactorOptions,
140			publishResponses:    make(chan *publishResponse, 100),
141			autoPublishResponse: true,
142			schemas:             map[string]*pb.Schema{},
143		},
144	}
145	pb.RegisterPublisherServer(srv.Gsrv, &s.GServer)
146	pb.RegisterSubscriberServer(srv.Gsrv, &s.GServer)
147	pb.RegisterSchemaServiceServer(srv.Gsrv, &s.GServer)
148	srv.Start()
149	return s
150}
151
152// SetTimeNowFunc registers f as a function to
153// be used instead of time.Now for this server.
154func (s *Server) SetTimeNowFunc(f func() time.Time) {
155	s.GServer.timeNowFunc = f
156}
157
158// Publish behaves as if the Publish RPC was called with a message with the given
159// data and attrs. It returns the ID of the message.
160// The topic will be created if it doesn't exist.
161//
162// Publish panics if there is an error, which is appropriate for testing.
163func (s *Server) Publish(topic string, data []byte, attrs map[string]string) string {
164	return s.PublishOrdered(topic, data, attrs, "")
165}
166
167// PublishOrdered behaves as if the Publish RPC was called with a message with the given
168// data, attrs and ordering key. It returns the ID of the message.
169// The topic will be created if it doesn't exist.
170//
171// PublishOrdered panics if there is an error, which is appropriate for testing.
172func (s *Server) PublishOrdered(topic string, data []byte, attrs map[string]string, orderingKey string) string {
173	const topicPattern = "projects/*/topics/*"
174	ok, err := path.Match(topicPattern, topic)
175	if err != nil {
176		panic(err)
177	}
178	if !ok {
179		panic(fmt.Sprintf("topic name must be of the form %q", topicPattern))
180	}
181	_, _ = s.GServer.CreateTopic(context.TODO(), &pb.Topic{Name: topic})
182	req := &pb.PublishRequest{
183		Topic:    topic,
184		Messages: []*pb.PubsubMessage{{Data: data, Attributes: attrs, OrderingKey: orderingKey}},
185	}
186	res, err := s.GServer.Publish(context.TODO(), req)
187	if err != nil {
188		panic(fmt.Sprintf("pstest.Server.Publish: %v", err))
189	}
190	return res.MessageIds[0]
191}
192
193// AddPublishResponse adds a new publish response to the channel used for
194// responding to publish requests.
195func (s *Server) AddPublishResponse(pbr *pb.PublishResponse, err error) {
196	pr := &publishResponse{}
197	if err != nil {
198		pr.err = err
199	} else {
200		pr.resp = pbr
201	}
202	s.GServer.publishResponses <- pr
203}
204
205// SetAutoPublishResponse controls whether to automatically respond
206// to messages published or to use user-added responses from the
207// publishResponses channel.
208func (s *Server) SetAutoPublishResponse(autoPublishResponse bool) {
209	s.GServer.mu.Lock()
210	defer s.GServer.mu.Unlock()
211	s.GServer.autoPublishResponse = autoPublishResponse
212}
213
214// ResetPublishResponses resets the buffered publishResponses channel
215// with a new buffered channel with the given size.
216func (s *Server) ResetPublishResponses(size int) {
217	s.GServer.mu.Lock()
218	defer s.GServer.mu.Unlock()
219	s.GServer.publishResponses = make(chan *publishResponse, size)
220}
221
222// SetStreamTimeout sets the amount of time a stream will be active before it shuts
223// itself down. This mimics the real service's behavior of closing streams after 30
224// minutes. If SetStreamTimeout is never called or is passed zero, streams never shut
225// down.
226func (s *Server) SetStreamTimeout(d time.Duration) {
227	s.GServer.mu.Lock()
228	defer s.GServer.mu.Unlock()
229	s.GServer.streamTimeout = d
230}
231
// A Message is a message that was published to the server.
//
// Deliveries, Acks and Modacks are snapshots. The server updates the
// unexported counterparts below, and copies them into these exported fields
// only when the message is returned by Server.Messages or Server.Message.
// Acks also counts messages that Seek dropped for being older than its target.
type Message struct {
	ID          string
	Data        []byte
	Attributes  map[string]string
	PublishTime time.Time
	Deliveries  int      // number of times delivery of the message was attempted
	Acks        int      // number of acks received from clients
	Modacks     []Modack // modacks received by server for this message
	OrderingKey string

	// protected by server mutex
	deliveries int
	acks       int
	modacks    []Modack
}
248
// Modack represents a modack sent to the server.
//
// ModifyAckDeadline records one Modack per ack ID in the request.
// AckDeadline is the requested deadline in seconds, exactly as sent.
// ReceivedAt is the time at which the server processed the request.
type Modack struct {
	AckID       string
	AckDeadline int32
	ReceivedAt  time.Time
}
255
256// Messages returns information about all messages ever published.
257func (s *Server) Messages() []*Message {
258	s.GServer.mu.Lock()
259	defer s.GServer.mu.Unlock()
260
261	var msgs []*Message
262	for _, m := range s.GServer.msgs {
263		m.Deliveries = m.deliveries
264		m.Acks = m.acks
265		m.Modacks = append([]Modack(nil), m.modacks...)
266		msgs = append(msgs, m)
267	}
268	return msgs
269}
270
271// Message returns the message with the given ID, or nil if no message
272// with that ID was published.
273func (s *Server) Message(id string) *Message {
274	s.GServer.mu.Lock()
275	defer s.GServer.mu.Unlock()
276
277	m := s.GServer.msgsByID[id]
278	if m != nil {
279		m.Deliveries = m.deliveries
280		m.Acks = m.acks
281		m.Modacks = append([]Modack(nil), m.modacks...)
282	}
283	return m
284}
285
286// Wait blocks until all server activity has completed.
287func (s *Server) Wait() {
288	s.GServer.wg.Wait()
289}
290
291// ClearMessages removes all published messages
292// from internal containers.
293func (s *Server) ClearMessages() {
294	s.GServer.mu.Lock()
295	s.GServer.msgs = nil
296	s.GServer.msgsByID = make(map[string]*Message)
297	s.GServer.mu.Unlock()
298}
299
300// Close shuts down the server and releases all resources.
301func (s *Server) Close() error {
302	s.srv.Close()
303	s.GServer.mu.Lock()
304	defer s.GServer.mu.Unlock()
305	for _, sub := range s.GServer.subs {
306		sub.stop()
307	}
308	return nil
309}
310
311func (s *GServer) CreateTopic(_ context.Context, t *pb.Topic) (*pb.Topic, error) {
312	s.mu.Lock()
313	defer s.mu.Unlock()
314
315	if handled, ret, err := s.runReactor(t, "CreateTopic", &pb.Topic{}); handled || err != nil {
316		return ret.(*pb.Topic), err
317	}
318
319	if s.topics[t.Name] != nil {
320		return nil, status.Errorf(codes.AlreadyExists, "topic %q", t.Name)
321	}
322	top := newTopic(t)
323	s.topics[t.Name] = top
324	return top.proto, nil
325}
326
327func (s *GServer) GetTopic(_ context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
328	s.mu.Lock()
329	defer s.mu.Unlock()
330
331	if handled, ret, err := s.runReactor(req, "GetTopic", &pb.Topic{}); handled || err != nil {
332		return ret.(*pb.Topic), err
333	}
334
335	if t := s.topics[req.Topic]; t != nil {
336		return t.proto, nil
337	}
338	return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
339}
340
341func (s *GServer) UpdateTopic(_ context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
342	s.mu.Lock()
343	defer s.mu.Unlock()
344
345	if handled, ret, err := s.runReactor(req, "UpdateTopic", &pb.Topic{}); handled || err != nil {
346		return ret.(*pb.Topic), err
347	}
348
349	t := s.topics[req.Topic.Name]
350	if t == nil {
351		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic.Name)
352	}
353	for _, path := range req.UpdateMask.Paths {
354		switch path {
355		case "labels":
356			t.proto.Labels = req.Topic.Labels
357		case "message_storage_policy":
358			t.proto.MessageStoragePolicy = req.Topic.MessageStoragePolicy
359		default:
360			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
361		}
362	}
363	return t.proto, nil
364}
365
366func (s *GServer) ListTopics(_ context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
367	s.mu.Lock()
368	defer s.mu.Unlock()
369
370	if handled, ret, err := s.runReactor(req, "ListTopics", &pb.ListTopicsResponse{}); handled || err != nil {
371		return ret.(*pb.ListTopicsResponse), err
372	}
373
374	var names []string
375	for n := range s.topics {
376		if strings.HasPrefix(n, req.Project) {
377			names = append(names, n)
378		}
379	}
380	sort.Strings(names)
381	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
382	if err != nil {
383		return nil, err
384	}
385	res := &pb.ListTopicsResponse{NextPageToken: nextToken}
386	for i := from; i < to; i++ {
387		res.Topics = append(res.Topics, s.topics[names[i]].proto)
388	}
389	return res, nil
390}
391
392func (s *GServer) ListTopicSubscriptions(_ context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
393	s.mu.Lock()
394	defer s.mu.Unlock()
395
396	if handled, ret, err := s.runReactor(req, "ListTopicSubscriptions", &pb.ListTopicSubscriptionsResponse{}); handled || err != nil {
397		return ret.(*pb.ListTopicSubscriptionsResponse), err
398	}
399
400	var names []string
401	for name, sub := range s.subs {
402		if sub.topic.proto.Name == req.Topic {
403			names = append(names, name)
404		}
405	}
406	sort.Strings(names)
407	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
408	if err != nil {
409		return nil, err
410	}
411	return &pb.ListTopicSubscriptionsResponse{
412		Subscriptions: names[from:to],
413		NextPageToken: nextToken,
414	}, nil
415}
416
417func (s *GServer) DeleteTopic(_ context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
418	s.mu.Lock()
419	defer s.mu.Unlock()
420
421	if handled, ret, err := s.runReactor(req, "DeleteTopic", &emptypb.Empty{}); handled || err != nil {
422		return ret.(*emptypb.Empty), err
423	}
424
425	t := s.topics[req.Topic]
426	if t == nil {
427		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
428	}
429	t.stop()
430	delete(s.topics, req.Topic)
431	return &emptypb.Empty{}, nil
432}
433
434func (s *GServer) CreateSubscription(_ context.Context, ps *pb.Subscription) (*pb.Subscription, error) {
435	s.mu.Lock()
436	defer s.mu.Unlock()
437
438	if handled, ret, err := s.runReactor(ps, "CreateSubscription", &pb.Subscription{}); handled || err != nil {
439		return ret.(*pb.Subscription), err
440	}
441
442	if ps.Name == "" {
443		return nil, status.Errorf(codes.InvalidArgument, "missing name")
444	}
445	if s.subs[ps.Name] != nil {
446		return nil, status.Errorf(codes.AlreadyExists, "subscription %q", ps.Name)
447	}
448	if ps.Topic == "" {
449		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
450	}
451	top := s.topics[ps.Topic]
452	if top == nil {
453		return nil, status.Errorf(codes.NotFound, "topic %q", ps.Topic)
454	}
455	if err := checkAckDeadline(ps.AckDeadlineSeconds); err != nil {
456		return nil, err
457	}
458	if ps.MessageRetentionDuration == nil {
459		ps.MessageRetentionDuration = defaultMessageRetentionDuration
460	}
461	if err := checkMRD(ps.MessageRetentionDuration); err != nil {
462		return nil, err
463	}
464	if ps.PushConfig == nil {
465		ps.PushConfig = &pb.PushConfig{}
466	}
467
468	sub := newSubscription(top, &s.mu, s.timeNowFunc, ps)
469	top.subs[ps.Name] = sub
470	s.subs[ps.Name] = sub
471	sub.start(&s.wg)
472	return ps, nil
473}
474
// Can be set for testing.
//
// minAckDeadlineSecs is the smallest ack_deadline_seconds that
// checkAckDeadline accepts. init sets it to 10 via ResetMinAckDeadline, and
// SetMinAckDeadline changes it. It is package-wide and unsynchronized.
var minAckDeadlineSecs int32
477
478// SetMinAckDeadline changes the minack deadline to n. Must be
479// greater than or equal to 1 second. Remember to reset this value
480// to the default after your test changes it. Example usage:
481// 		pstest.SetMinAckDeadlineSecs(1)
482// 		defer pstest.ResetMinAckDeadlineSecs()
483func SetMinAckDeadline(n time.Duration) {
484	if n < time.Second {
485		panic("SetMinAckDeadline expects a value greater than 1 second")
486	}
487
488	minAckDeadlineSecs = int32(n / time.Second)
489}
490
491// ResetMinAckDeadline resets the minack deadline to the default.
492func ResetMinAckDeadline() {
493	minAckDeadlineSecs = 10
494}
495
496func checkAckDeadline(ads int32) error {
497	if ads < minAckDeadlineSecs || ads > 600 {
498		// PubSub service returns Unknown.
499		return status.Errorf(codes.Unknown, "bad ack_deadline_seconds: %d", ads)
500	}
501	return nil
502}
503
// Inclusive bounds that checkMRD enforces on a subscription's
// message_retention_duration (10 minutes to 7 days).
const (
	minMessageRetentionDuration = 10 * time.Minute
	maxMessageRetentionDuration = 168 * time.Hour
)
508
// defaultMessageRetentionDuration is what CreateSubscription stores when a
// request omits message_retention_duration. It is the maximum allowed
// duration. The same pointer is shared by every such subscription, so it must
// not be mutated.
var defaultMessageRetentionDuration = durpb.New(maxMessageRetentionDuration)
510
511func checkMRD(pmrd *durpb.Duration) error {
512	mrd := pmrd.AsDuration()
513	if mrd < minMessageRetentionDuration || mrd > maxMessageRetentionDuration {
514		return status.Errorf(codes.InvalidArgument, "bad message_retention_duration %+v", pmrd)
515	}
516	return nil
517}
518
519func (s *GServer) GetSubscription(_ context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
520	s.mu.Lock()
521	defer s.mu.Unlock()
522
523	if handled, ret, err := s.runReactor(req, "GetSubscription", &pb.Subscription{}); handled || err != nil {
524		return ret.(*pb.Subscription), err
525	}
526
527	sub, err := s.findSubscription(req.Subscription)
528	if err != nil {
529		return nil, err
530	}
531	return sub.proto, nil
532}
533
534func (s *GServer) UpdateSubscription(_ context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
535	if req.Subscription == nil {
536		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
537	}
538	s.mu.Lock()
539	defer s.mu.Unlock()
540
541	if handled, ret, err := s.runReactor(req, "UpdateSubscription", &pb.Subscription{}); handled || err != nil {
542		return ret.(*pb.Subscription), err
543	}
544
545	sub, err := s.findSubscription(req.Subscription.Name)
546	if err != nil {
547		return nil, err
548	}
549	for _, path := range req.UpdateMask.Paths {
550		switch path {
551		case "push_config":
552			sub.proto.PushConfig = req.Subscription.PushConfig
553
554		case "ack_deadline_seconds":
555			a := req.Subscription.AckDeadlineSeconds
556			if err := checkAckDeadline(a); err != nil {
557				return nil, err
558			}
559			sub.proto.AckDeadlineSeconds = a
560
561		case "retain_acked_messages":
562			sub.proto.RetainAckedMessages = req.Subscription.RetainAckedMessages
563
564		case "message_retention_duration":
565			if err := checkMRD(req.Subscription.MessageRetentionDuration); err != nil {
566				return nil, err
567			}
568			sub.proto.MessageRetentionDuration = req.Subscription.MessageRetentionDuration
569
570		case "labels":
571			sub.proto.Labels = req.Subscription.Labels
572
573		case "expiration_policy":
574			sub.proto.ExpirationPolicy = req.Subscription.ExpirationPolicy
575
576		case "dead_letter_policy":
577			sub.proto.DeadLetterPolicy = req.Subscription.DeadLetterPolicy
578
579		case "retry_policy":
580			sub.proto.RetryPolicy = req.Subscription.RetryPolicy
581
582		case "filter":
583			sub.proto.Filter = req.Subscription.Filter
584
585		default:
586			return nil, status.Errorf(codes.InvalidArgument, "unknown field name %q", path)
587		}
588	}
589	return sub.proto, nil
590}
591
592func (s *GServer) ListSubscriptions(_ context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
593	s.mu.Lock()
594	defer s.mu.Unlock()
595
596	if handled, ret, err := s.runReactor(req, "ListSubscriptions", &pb.ListSubscriptionsResponse{}); handled || err != nil {
597		return ret.(*pb.ListSubscriptionsResponse), err
598	}
599
600	var names []string
601	for name := range s.subs {
602		if strings.HasPrefix(name, req.Project) {
603			names = append(names, name)
604		}
605	}
606	sort.Strings(names)
607	from, to, nextToken, err := testutil.PageBounds(int(req.PageSize), req.PageToken, len(names))
608	if err != nil {
609		return nil, err
610	}
611	res := &pb.ListSubscriptionsResponse{NextPageToken: nextToken}
612	for i := from; i < to; i++ {
613		res.Subscriptions = append(res.Subscriptions, s.subs[names[i]].proto)
614	}
615	return res, nil
616}
617
618func (s *GServer) DeleteSubscription(_ context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
619	s.mu.Lock()
620	defer s.mu.Unlock()
621
622	if handled, ret, err := s.runReactor(req, "DeleteSubscription", &emptypb.Empty{}); handled || err != nil {
623		return ret.(*emptypb.Empty), err
624	}
625
626	sub, err := s.findSubscription(req.Subscription)
627	if err != nil {
628		return nil, err
629	}
630	sub.stop()
631	delete(s.subs, req.Subscription)
632	sub.topic.deleteSub(sub)
633	return &emptypb.Empty{}, nil
634}
635
636func (s *GServer) DetachSubscription(_ context.Context, req *pb.DetachSubscriptionRequest) (*pb.DetachSubscriptionResponse, error) {
637	s.mu.Lock()
638	defer s.mu.Unlock()
639
640	if handled, ret, err := s.runReactor(req, "DetachSubscription", &pb.DetachSubscriptionResponse{}); handled || err != nil {
641		return ret.(*pb.DetachSubscriptionResponse), err
642	}
643
644	sub, err := s.findSubscription(req.Subscription)
645	if err != nil {
646		return nil, err
647	}
648	sub.topic.deleteSub(sub)
649	return &pb.DetachSubscriptionResponse{}, nil
650}
651
652func (s *GServer) Publish(_ context.Context, req *pb.PublishRequest) (*pb.PublishResponse, error) {
653	s.mu.Lock()
654	defer s.mu.Unlock()
655
656	if handled, ret, err := s.runReactor(req, "Publish", &pb.PublishResponse{}); handled || err != nil {
657		return ret.(*pb.PublishResponse), err
658	}
659
660	if req.Topic == "" {
661		return nil, status.Errorf(codes.InvalidArgument, "missing topic")
662	}
663	top := s.topics[req.Topic]
664	if top == nil {
665		return nil, status.Errorf(codes.NotFound, "topic %q", req.Topic)
666	}
667
668	if !s.autoPublishResponse {
669		r := <-s.publishResponses
670		if r.err != nil {
671			return nil, r.err
672		}
673		return r.resp, nil
674	}
675
676	var ids []string
677	for _, pm := range req.Messages {
678		id := fmt.Sprintf("m%d", s.nextID)
679		s.nextID++
680		pm.MessageId = id
681		pubTime := s.timeNowFunc()
682		tsPubTime := timestamppb.New(pubTime)
683		pm.PublishTime = tsPubTime
684		m := &Message{
685			ID:          id,
686			Data:        pm.Data,
687			Attributes:  pm.Attributes,
688			PublishTime: pubTime,
689			OrderingKey: pm.OrderingKey,
690		}
691		top.publish(pm, m)
692		ids = append(ids, id)
693		s.msgs = append(s.msgs, m)
694		s.msgsByID[id] = m
695	}
696	return &pb.PublishResponse{MessageIds: ids}, nil
697}
698
// topic is the server-side state of one topic.
type topic struct {
	proto *pb.Topic
	// subs holds the subscriptions currently attached to the topic, keyed by
	// subscription name; publish fans messages out to each of them.
	subs map[string]*subscription
}
703
704func newTopic(pt *pb.Topic) *topic {
705	return &topic{
706		proto: pt,
707		subs:  map[string]*subscription{},
708	}
709}
710
711func (t *topic) stop() {
712	for _, sub := range t.subs {
713		sub.proto.Topic = "_deleted-topic_"
714	}
715}
716
717func (t *topic) deleteSub(sub *subscription) {
718	delete(t.subs, sub.proto.Name)
719}
720
721func (t *topic) publish(pm *pb.PubsubMessage, m *Message) {
722	for _, s := range t.subs {
723		s.msgs[pm.MessageId] = &message{
724			publishTime: m.PublishTime,
725			proto: &pb.ReceivedMessage{
726				AckId:   pm.MessageId,
727				Message: pm,
728			},
729			deliveries:  &m.deliveries,
730			acks:        &m.acks,
731			streamIndex: -1,
732		}
733	}
734}
735
// subscription is the server-side state of one subscription.
type subscription struct {
	// topic is the topic the subscription was created on. It is not cleared
	// when the subscription is detached or the topic is deleted.
	topic *topic
	mu    *sync.Mutex // the server mutex, here for convenience
	proto *pb.Subscription
	// ackTimeout is the ack deadline applied by pull and given to new streams.
	ackTimeout time.Duration
	msgs       map[string]*message // unacked messages by message ID
	// streams are the open StreamingPull streams, in round-robin order.
	streams []*stream
	// done is closed by stop to end the delivery goroutine started by start.
	done chan struct{}
	// timeNowFunc is the clock used for ack deadlines and retention.
	timeNowFunc func() time.Time
}
746
747func newSubscription(t *topic, mu *sync.Mutex, timeNowFunc func() time.Time, ps *pb.Subscription) *subscription {
748	at := time.Duration(ps.AckDeadlineSeconds) * time.Second
749	if at == 0 {
750		at = 10 * time.Second
751	}
752	return &subscription{
753		topic:       t,
754		mu:          mu,
755		proto:       ps,
756		ackTimeout:  at,
757		msgs:        map[string]*message{},
758		done:        make(chan struct{}),
759		timeNowFunc: timeNowFunc,
760	}
761}
762
763func (s *subscription) start(wg *sync.WaitGroup) {
764	wg.Add(1)
765	go func() {
766		defer wg.Done()
767		for {
768			select {
769			case <-s.done:
770				return
771			case <-time.After(10 * time.Millisecond):
772				s.deliver()
773			}
774		}
775	}()
776}
777
778func (s *subscription) stop() {
779	close(s.done)
780}
781
782func (s *GServer) Acknowledge(_ context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
783	s.mu.Lock()
784	defer s.mu.Unlock()
785
786	if handled, ret, err := s.runReactor(req, "Acknowledge", &emptypb.Empty{}); handled || err != nil {
787		return ret.(*emptypb.Empty), err
788	}
789
790	sub, err := s.findSubscription(req.Subscription)
791	if err != nil {
792		return nil, err
793	}
794	for _, id := range req.AckIds {
795		sub.ack(id)
796	}
797	return &emptypb.Empty{}, nil
798}
799
800func (s *GServer) ModifyAckDeadline(_ context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
801	s.mu.Lock()
802	defer s.mu.Unlock()
803
804	if handled, ret, err := s.runReactor(req, "ModifyAckDeadline", &emptypb.Empty{}); handled || err != nil {
805		return ret.(*emptypb.Empty), err
806	}
807
808	sub, err := s.findSubscription(req.Subscription)
809	if err != nil {
810		return nil, err
811	}
812	now := time.Now()
813	for _, id := range req.AckIds {
814		s.msgsByID[id].modacks = append(s.msgsByID[id].modacks, Modack{AckID: id, AckDeadline: req.AckDeadlineSeconds, ReceivedAt: now})
815	}
816	dur := secsToDur(req.AckDeadlineSeconds)
817	for _, id := range req.AckIds {
818		sub.modifyAckDeadline(id, dur)
819	}
820	return &emptypb.Empty{}, nil
821}
822
823func (s *GServer) Pull(ctx context.Context, req *pb.PullRequest) (*pb.PullResponse, error) {
824	s.mu.Lock()
825
826	if handled, ret, err := s.runReactor(req, "Pull", &pb.PullResponse{}); handled || err != nil {
827		s.mu.Unlock()
828		return ret.(*pb.PullResponse), err
829	}
830
831	sub, err := s.findSubscription(req.Subscription)
832	if err != nil {
833		s.mu.Unlock()
834		return nil, err
835	}
836	max := int(req.MaxMessages)
837	if max < 0 {
838		s.mu.Unlock()
839		return nil, status.Error(codes.InvalidArgument, "MaxMessages cannot be negative")
840	}
841	if max == 0 { // MaxMessages not specified; use a default.
842		max = 1000
843	}
844	msgs := sub.pull(max)
845	s.mu.Unlock()
846	// Implement the spec from the pubsub proto:
847	// "If ReturnImmediately set to true, the system will respond immediately even if
848	// it there are no messages available to return in the `Pull` response.
849	// Otherwise, the system may wait (for a bounded amount of time) until at
850	// least one message is available, rather than returning no messages."
851	if len(msgs) == 0 && !req.ReturnImmediately {
852		// Wait for a short amount of time for a message.
853		// TODO: signal when a message arrives, so we don't wait the whole time.
854		select {
855		case <-ctx.Done():
856			return nil, ctx.Err()
857		case <-time.After(500 * time.Millisecond):
858			s.mu.Lock()
859			msgs = sub.pull(max)
860			s.mu.Unlock()
861		}
862	}
863	return &pb.PullResponse{ReceivedMessages: msgs}, nil
864}
865
866func (s *GServer) StreamingPull(sps pb.Subscriber_StreamingPullServer) error {
867	// Receive initial message configuring the pull.
868	req, err := sps.Recv()
869	if err != nil {
870		return err
871	}
872	s.mu.Lock()
873	sub, err := s.findSubscription(req.Subscription)
874	s.mu.Unlock()
875	if err != nil {
876		return err
877	}
878	// Create a new stream to handle the pull.
879	st := sub.newStream(sps, s.streamTimeout)
880	err = st.pull(&s.wg)
881	sub.deleteStream(st)
882	return err
883}
884
885func (s *GServer) Seek(ctx context.Context, req *pb.SeekRequest) (*pb.SeekResponse, error) {
886	// Only handle time-based seeking for now.
887	// This fake doesn't deal with snapshots.
888	var target time.Time
889	switch v := req.Target.(type) {
890	case nil:
891		return nil, status.Errorf(codes.InvalidArgument, "missing Seek target type")
892	case *pb.SeekRequest_Time:
893		target = v.Time.AsTime()
894	default:
895		return nil, status.Errorf(codes.Unimplemented, "unhandled Seek target type %T", v)
896	}
897
898	// The entire server must be locked while doing the work below,
899	// because the messages don't have any other synchronization.
900	s.mu.Lock()
901	defer s.mu.Unlock()
902
903	if handled, ret, err := s.runReactor(req, "Seek", &pb.SeekResponse{}); handled || err != nil {
904		return ret.(*pb.SeekResponse), err
905	}
906
907	sub, err := s.findSubscription(req.Subscription)
908	if err != nil {
909		return nil, err
910	}
911	// Drop all messages from sub that were published before the target time.
912	for id, m := range sub.msgs {
913		if m.publishTime.Before(target) {
914			delete(sub.msgs, id)
915			(*m.acks)++
916		}
917	}
918	// Un-ack any already-acked messages after this time;
919	// redelivering them to the subscription is the closest analogue here.
920	for _, m := range s.msgs {
921		if m.PublishTime.Before(target) {
922			continue
923		}
924		sub.msgs[m.ID] = &message{
925			publishTime: m.PublishTime,
926			proto: &pb.ReceivedMessage{
927				AckId: m.ID,
928				// This was not preserved!
929				//Message: pm,
930			},
931			deliveries:  &m.deliveries,
932			acks:        &m.acks,
933			streamIndex: -1,
934		}
935	}
936	return &pb.SeekResponse{}, nil
937}
938
939// Gets a subscription that must exist.
940// Must be called with the lock held.
941func (s *GServer) findSubscription(name string) (*subscription, error) {
942	if name == "" {
943		return nil, status.Errorf(codes.InvalidArgument, "missing subscription")
944	}
945	sub := s.subs[name]
946	if sub == nil {
947		return nil, status.Errorf(codes.NotFound, "subscription %s", name)
948	}
949	return sub, nil
950}
951
952// Must be called with the lock held.
953func (s *subscription) pull(max int) []*pb.ReceivedMessage {
954	now := s.timeNowFunc()
955	s.maintainMessages(now)
956	var msgs []*pb.ReceivedMessage
957	for _, m := range s.msgs {
958		if m.outstanding() {
959			continue
960		}
961		(*m.deliveries)++
962		m.ackDeadline = now.Add(s.ackTimeout)
963		msgs = append(msgs, m.proto)
964		if len(msgs) >= max {
965			break
966		}
967	}
968	return msgs
969}
970
// deliver is one pass of the subscription's background goroutine (see start).
// Under the server mutex it refreshes message state via maintainMessages. It
// then offers each message that is not outstanding to the open streams,
// round-robin, stopping at the first message that no stream will accept
// right now.
func (s *subscription) deliver() {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := s.timeNowFunc()
	s.maintainMessages(now)
	// Try to deliver each remaining message.
	curIndex := 0
	for _, m := range s.msgs {
		if m.outstanding() {
			continue
		}
		// If the message was never delivered before, start with the stream at
		// curIndex. If it was delivered before, start with the stream after the one
		// that owned it.
		if m.streamIndex < 0 {
			delIndex, ok := s.tryDeliverMessage(m, curIndex, now)
			if !ok {
				// No stream can take a message right now; stop this pass.
				break
			}
			// Advance so the next new message starts at the following stream.
			curIndex = delIndex + 1
			// NOTE(review): this stores the index *after* the receiving stream,
			// while the else branch below stores the receiving stream itself
			// (delIndex). The two branches disagree about what streamIndex means.
			m.streamIndex = curIndex
		} else {
			delIndex, ok := s.tryDeliverMessage(m, m.streamIndex, now)
			if !ok {
				break
			}
			m.streamIndex = delIndex
		}
	}
}
1002
// tryDeliverMessage attempts to deliver m to the stream at index i. If it can't, it
// tries streams i+1, i+2, ..., wrapping around. Once it's tried all streams, it
// exits.
//
// Sends are non-blocking (select with default), so a stream that is not ready
// to receive is skipped. A stream whose done channel is closed is removed from
// s.streams when it is encountered. On success the message's delivery count is
// incremented and its ack deadline set from that stream's ackTimeout.
//
// It returns the index of the stream it delivered the message to, or 0, false if
// it didn't deliver the message.
//
// Must be called with the lock held.
func (s *subscription) tryDeliverMessage(m *message, start int, now time.Time) (int, bool) {
	for i := 0; i < len(s.streams); i++ {
		idx := (i + start) % len(s.streams)

		st := s.streams[idx]
		select {
		case <-st.done:
			s.streams = deleteStreamAt(s.streams, idx)
			// The slice just shrank by one; retry this iteration count against
			// the shortened slice.
			i--

		case st.msgc <- m.proto:
			(*m.deliveries)++
			m.ackDeadline = now.Add(st.ackTimeout)
			return idx, true

		default:
		}
	}
	return 0, false
}
1031
// retentionDuration is how long after its publish time an unowned message is
// kept before maintainMessages drops it. It is a variable rather than a
// constant, presumably so tests can shorten it.
var retentionDuration = time.Minute * 10
1033
// maintainMessages does housekeeping on s.msgs as of now. A message whose ack
// deadline has passed is released so it can be redelivered. A message that is
// not owned (after that release) and was published more than retentionDuration
// ago is removed.
//
// Must be called with the lock held.
func (s *subscription) maintainMessages(now time.Time) {
	for ackID, msg := range s.msgs {
		if expired := msg.outstanding() && now.After(msg.ackDeadline); expired {
			msg.makeAvailable()
		}
		age := now.Sub(msg.proto.Message.PublishTime.AsTime())
		if msg.outstanding() {
			continue
		}
		if age > retentionDuration {
			// Deleting the current key while ranging over a map is safe in Go.
			delete(s.msgs, ackID)
		}
	}
}
1048
// newStream creates a stream bound to the gRPC StreamingPull server gs and
// registers it with s for delivery. The stream starts with the subscription's
// ack timeout; timeout is stored for stream.pull. Its message channel is
// unbuffered.
func (s *subscription) newStream(gs pb.Subscriber_StreamingPullServer, timeout time.Duration) *stream {
	st := new(stream)
	st.sub = s
	st.gstream = gs
	st.done = make(chan struct{})
	st.msgc = make(chan *pb.ReceivedMessage)
	st.ackTimeout = s.ackTimeout
	st.timeout = timeout

	s.mu.Lock()
	defer s.mu.Unlock()
	s.streams = append(s.streams, st)
	return st
}
1063
1064func (s *subscription) deleteStream(st *stream) {
1065	s.mu.Lock()
1066	defer s.mu.Unlock()
1067	var i int
1068	for i = 0; i < len(s.streams); i++ {
1069		if s.streams[i] == st {
1070			break
1071		}
1072	}
1073	if i < len(s.streams) {
1074		s.streams = deleteStreamAt(s.streams, i)
1075	}
1076}
1077func deleteStreamAt(s []*stream, i int) []*stream {
1078	// Preserve order for round-robin delivery.
1079	return append(s[:i], s[i+1:]...)
1080}
1081
// message is a subscription's bookkeeping for one message it holds. Messages
// live in the subscription's msgs map, looked up by ack ID (see ack).
type message struct {
	// proto is what pull returns and what tryDeliverMessage sends to a stream.
	proto       *pb.ReceivedMessage
	// NOTE(review): publishTime is not read in this part of the file;
	// maintainMessages uses proto.Message.PublishTime instead. Confirm
	// whether it is still needed.
	publishTime time.Time
	// ackDeadline is zero while no puller or stream owns the message (see
	// outstanding and makeAvailable). Otherwise it is when ownership lapses.
	ackDeadline time.Time
	// deliveries and acks are incremented on each delivery and on ack. They
	// are pointers, presumably so the counts are shared with state outside
	// this struct. Confirm against where messages are created.
	deliveries  *int
	acks        *int
	// streamIndex is negative until the message is first delivered on a
	// stream. After that, deliver records a stream index that determines
	// where the round-robin search starts on redelivery.
	streamIndex int
}
1090
1091// A message is outstanding if it is owned by some stream.
1092func (m *message) outstanding() bool {
1093	return !m.ackDeadline.IsZero()
1094}
1095
1096func (m *message) makeAvailable() {
1097	m.ackDeadline = time.Time{}
1098}
1099
// stream is the server side of one StreamingPull call on a subscription.
type stream struct {
	// sub is the subscription this stream receives messages for.
	sub        *subscription
	done       chan struct{} // closed when the stream is finished
	// msgc carries messages from tryDeliverMessage to sendLoop. newStream
	// makes it unbuffered, so a send succeeds only while sendLoop is waiting.
	msgc       chan *pb.ReceivedMessage
	// gstream is the gRPC stream used to send responses and receive requests.
	gstream    pb.Subscriber_StreamingPullServer
	// ackTimeout is the ack deadline given to messages delivered on this
	// stream. It starts at the subscription's value and can be replaced by a
	// request's StreamAckDeadlineSeconds.
	ackTimeout time.Duration
	// timeout, when positive, bounds how long pull keeps the stream open.
	timeout    time.Duration
}
1108
// pull manages the StreamingPull interaction for the life of the stream. It
// runs sendLoop and recvLoop in their own goroutines (registered with wg) and
// blocks until one of them returns or, when st.timeout > 0, the timeout elapses.
// It then closes st.done, which makes sendLoop return. recvLoop ends only when
// its Recv fails.
//
// The first loop error is returned, except io.EOF (a normal client close),
// which is reported as nil. A timeout also yields nil.
func (st *stream) pull(wg *sync.WaitGroup) error {
	// Room for both results, so neither goroutine blocks on send after pull
	// has stopped listening.
	results := make(chan error, 2)
	loops := []func() error{st.sendLoop, st.recvLoop}
	wg.Add(len(loops))
	for _, loop := range loops {
		go func(run func() error) {
			defer wg.Done()
			results <- run()
		}(loop)
	}

	// A nil channel never fires, so without a timeout only results is awaited.
	var expired <-chan time.Time
	if st.timeout > 0 {
		expired = time.After(st.timeout)
	}

	var err error
	select {
	case err = <-results:
	case <-expired:
	}
	if err == io.EOF {
		err = nil
	}
	close(st.done)
	return err
}
1137
1138func (st *stream) sendLoop() error {
1139	for {
1140		select {
1141		case <-st.done:
1142			return nil
1143		case rm := <-st.msgc:
1144			res := &pb.StreamingPullResponse{ReceivedMessages: []*pb.ReceivedMessage{rm}}
1145			if err := st.gstream.Send(res); err != nil {
1146				return err
1147			}
1148		}
1149	}
1150}
1151
1152func (st *stream) recvLoop() error {
1153	for {
1154		req, err := st.gstream.Recv()
1155		if err != nil {
1156			return err
1157		}
1158		st.sub.handleStreamingPullRequest(st, req)
1159	}
1160}
1161
1162func (s *subscription) handleStreamingPullRequest(st *stream, req *pb.StreamingPullRequest) {
1163	// Lock the entire server.
1164	s.mu.Lock()
1165	defer s.mu.Unlock()
1166
1167	for _, ackID := range req.AckIds {
1168		s.ack(ackID)
1169	}
1170	for i, id := range req.ModifyDeadlineAckIds {
1171		s.modifyAckDeadline(id, secsToDur(req.ModifyDeadlineSeconds[i]))
1172	}
1173	if req.StreamAckDeadlineSeconds > 0 {
1174		st.ackTimeout = secsToDur(req.StreamAckDeadlineSeconds)
1175	}
1176}
1177
1178// Must be called with the lock held.
1179func (s *subscription) ack(id string) {
1180	m := s.msgs[id]
1181	if m != nil {
1182		(*m.acks)++
1183		delete(s.msgs, id)
1184	}
1185}
1186
// modifyAckDeadline changes the ack deadline of the message with ack ID id. A
// zero d is a nack, which releases the message for redelivery. Any other d
// resets the deadline to now+d. Unknown (already-acked) IDs are ignored.
//
// Must be called with the lock held.
func (s *subscription) modifyAckDeadline(id string, d time.Duration) {
	msg := s.msgs[id]
	switch {
	case msg == nil:
		// Already acked: nothing to modify.
	case d == 0:
		msg.makeAvailable()
	default:
		msg.ackDeadline = s.timeNowFunc().Add(d)
	}
}
1199
// secsToDur converts a whole number of seconds, as carried in the int32
// request fields, into a time.Duration.
func secsToDur(secs int32) time.Duration {
	d := time.Duration(secs)
	return d * time.Second
}
1203
1204// runReactor looks up the reactors for a function, then launches them until handled=true
1205// or err is returned. If the reactor returns nil, the function returns defaultObj instead.
1206func (s *GServer) runReactor(req interface{}, funcName string, defaultObj interface{}) (bool, interface{}, error) {
1207	if val, ok := s.reactorOptions[funcName]; ok {
1208		for _, reactor := range val {
1209			handled, ret, err := reactor.React(req)
1210			// If handled=true, that means the reactor has successfully reacted to the request,
1211			// so use the output directly. If err occurs, that means the request is invalidated
1212			// by the reactor somehow.
1213			if handled || err != nil {
1214				if ret == nil {
1215					ret = defaultObj
1216				}
1217				return true, ret, err
1218			}
1219		}
1220	}
1221	return false, nil, nil
1222}
1223
// errorInjectionReactor is a reactor to inject an error message with status code.
type errorInjectionReactor struct {
	// code is the gRPC status code of the injected error.
	code codes.Code
	// msg is the injected error's message text.
	msg  string
}
1229
1230// React simply returns an error with defined error message and status code.
1231func (e *errorInjectionReactor) React(_ interface{}) (handled bool, ret interface{}, err error) {
1232	return true, nil, status.Errorf(e.code, e.msg)
1233}
1234
1235// WithErrorInjection creates a ServerReactorOption that injects error with defined status code and
1236// message for a certain function.
1237func WithErrorInjection(funcName string, code codes.Code, msg string) ServerReactorOption {
1238	return ServerReactorOption{
1239		FuncName: funcName,
1240		Reactor:  &errorInjectionReactor{code: code, msg: msg},
1241	}
1242}
1243
// CreateSchema stores a schema named "<Parent>/schemas/<SchemaId>" with the
// type and definition from req.Schema, replacing any schema already stored
// under that name, and returns it. Registered "CreateSchema" reactors are
// consulted first. A request without a schema body yields InvalidArgument.
func (s *GServer) CreateSchema(_ context.Context, req *pb.CreateSchemaRequest) (*pb.Schema, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if handled, ret, err := s.runReactor(req, "CreateSchema", &pb.Schema{}); handled || err != nil {
		return ret.(*pb.Schema), err
	}

	// Without this check, a missing schema body would nil-dereference below
	// and panic inside the gRPC handler.
	if req.Schema == nil {
		return nil, status.Error(codes.InvalidArgument, "schema must be provided")
	}

	name := fmt.Sprintf("%s/schemas/%s", req.Parent, req.SchemaId)
	sc := &pb.Schema{
		Name:       name,
		Type:       req.Schema.Type,
		Definition: req.Schema.Definition,
	}
	s.schemas[name] = sc

	return sc, nil
}
1262
// GetSchema returns the schema stored under req.Name, or a NotFound error.
// Registered "GetSchema" reactors are consulted first.
func (s *GServer) GetSchema(_ context.Context, req *pb.GetSchemaRequest) (*pb.Schema, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	handled, ret, err := s.runReactor(req, "GetSchema", &pb.Schema{})
	if handled || err != nil {
		return ret.(*pb.Schema), err
	}

	if sc, found := s.schemas[req.Name]; found {
		return sc, nil
	}
	return nil, status.Errorf(codes.NotFound, "schema(%q) not found", req.Name)
}
1278
// ListSchemas returns the stored schemas, sorted by name so results are
// deterministic. When req.Parent is set, only schemas whose names begin with
// "<Parent>/schemas/" (the form CreateSchema builds) are returned. An empty
// Parent lists every schema, as before. Paging and the requested view are
// ignored. Registered "ListSchemas" reactors are consulted first.
func (s *GServer) ListSchemas(_ context.Context, req *pb.ListSchemasRequest) (*pb.ListSchemasResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if handled, ret, err := s.runReactor(req, "ListSchemas", &pb.ListSchemasResponse{}); handled || err != nil {
		return ret.(*pb.ListSchemasResponse), err
	}

	prefix := req.Parent + "/schemas/"
	ss := make([]*pb.Schema, 0, len(s.schemas))
	for _, sc := range s.schemas {
		if req.Parent != "" && !strings.HasPrefix(sc.GetName(), prefix) {
			continue
		}
		ss = append(ss, sc)
	}
	// Map iteration order is random; sort so repeated calls agree.
	sort.Slice(ss, func(i, j int) bool { return ss[i].GetName() < ss[j].GetName() })
	return &pb.ListSchemasResponse{
		Schemas: ss,
	}, nil
}
1294
// DeleteSchema removes the schema stored under req.Name, or returns NotFound if
// there is none. Registered "DeleteSchema" reactors are consulted first.
func (s *GServer) DeleteSchema(_ context.Context, req *pb.DeleteSchemaRequest) (*emptypb.Empty, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	handled, ret, err := s.runReactor(req, "DeleteSchema", &emptypb.Empty{})
	if handled || err != nil {
		return ret.(*emptypb.Empty), err
	}

	if s.schemas[req.Name] == nil {
		return nil, status.Errorf(codes.NotFound, "schema %q", req.Name)
	}
	delete(s.schemas, req.Name)
	return &emptypb.Empty{}, nil
}
1311
// ValidateSchema mocks the ValidateSchema call but only checks that the schema
// definition is not empty. A request with no schema at all is rejected the same
// way instead of panicking. Registered "ValidateSchema" reactors are consulted
// first.
func (s *GServer) ValidateSchema(_ context.Context, req *pb.ValidateSchemaRequest) (*pb.ValidateSchemaResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if handled, ret, err := s.runReactor(req, "ValidateSchema", &pb.ValidateSchemaResponse{}); handled || err != nil {
		return ret.(*pb.ValidateSchemaResponse), err
	}

	// Generated protobuf getters are nil-safe, so a missing req.Schema reads as
	// an empty definition rather than dereferencing nil.
	if req.GetSchema().GetDefinition() == "" {
		return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
	}
	return &pb.ValidateSchemaResponse{}, nil
}
1326
1327// ValidateMessage mocks the ValidateMessage call but only checks that the schema definition to validate the
1328// message against is not empty.
1329func (s *GServer) ValidateMessage(_ context.Context, req *pb.ValidateMessageRequest) (*pb.ValidateMessageResponse, error) {
1330	s.mu.Lock()
1331	defer s.mu.Unlock()
1332
1333	if handled, ret, err := s.runReactor(req, "ValidateMessage", &pb.ValidateMessageResponse{}); handled || err != nil {
1334		return ret.(*pb.ValidateMessageResponse), err
1335	}
1336
1337	spec := req.GetSchemaSpec()
1338	if valReq, ok := spec.(*pb.ValidateMessageRequest_Name); ok {
1339		sc, ok := s.schemas[valReq.Name]
1340		if !ok {
1341			return nil, status.Errorf(codes.NotFound, "schema(%q) not found", valReq.Name)
1342		}
1343		if sc.Definition == "" {
1344			return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
1345		}
1346	}
1347	if valReq, ok := spec.(*pb.ValidateMessageRequest_Schema); ok {
1348		if valReq.Schema.Definition == "" {
1349			return nil, status.Error(codes.InvalidArgument, "schema definition cannot be empty")
1350		}
1351	}
1352
1353	return &pb.ValidateMessageResponse{}, nil
1354}
1355