1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package pubsub
16
17// This file provides a mock in-memory pubsub server for streaming pull testing.
18
import (
	"context"
	"fmt"
	"io"
	"sync"
	"time"

	"cloud.google.com/go/internal/testutil"
	pb "google.golang.org/genproto/googleapis/pubsub/v1"
	"google.golang.org/protobuf/types/known/emptypb"
)
29
30type mockServer struct {
31	srv *testutil.Server
32
33	pb.SubscriberServer
34
35	Addr string
36
37	mu            sync.Mutex
38	Acked         map[string]bool  // acked message IDs
39	Deadlines     map[string]int32 // deadlines by message ID
40	pullResponses []*pullResponse
41	ackErrs       []error
42	modAckErrs    []error
43	wg            sync.WaitGroup
44	sub           *pb.Subscription
45}
46
47type pullResponse struct {
48	msgs []*pb.ReceivedMessage
49	err  error
50}
51
52func newMockServer(port int) (*mockServer, error) {
53	srv, err := testutil.NewServerWithPort(port)
54	if err != nil {
55		return nil, err
56	}
57	mock := &mockServer{
58		srv:       srv,
59		Addr:      srv.Addr,
60		Acked:     map[string]bool{},
61		Deadlines: map[string]int32{},
62		sub: &pb.Subscription{
63			AckDeadlineSeconds: 10,
64			PushConfig:         &pb.PushConfig{},
65		},
66	}
67	pb.RegisterSubscriberServer(srv.Gsrv, mock)
68	srv.Start()
69	return mock, nil
70}
71
72// Each call to addStreamingPullMessages results in one StreamingPullResponse.
73func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) {
74	s.mu.Lock()
75	s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil})
76	s.mu.Unlock()
77}
78
79func (s *mockServer) addStreamingPullError(err error) {
80	s.mu.Lock()
81	s.pullResponses = append(s.pullResponses, &pullResponse{nil, err})
82	s.mu.Unlock()
83}
84
85func (s *mockServer) addAckResponse(err error) {
86	s.mu.Lock()
87	s.ackErrs = append(s.ackErrs, err)
88	s.mu.Unlock()
89}
90
91func (s *mockServer) addModAckResponse(err error) {
92	s.mu.Lock()
93	s.modAckErrs = append(s.modAckErrs, err)
94	s.mu.Unlock()
95}
96
97func (s *mockServer) wait() {
98	s.wg.Wait()
99}
100
101func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error {
102	s.wg.Add(1)
103	defer s.wg.Done()
104	errc := make(chan error, 1)
105	s.wg.Add(1)
106	go func() {
107		defer s.wg.Done()
108		for {
109			req, err := stream.Recv()
110			if err != nil {
111				errc <- err
112				return
113			}
114			s.mu.Lock()
115			for _, id := range req.AckIds {
116				s.Acked[id] = true
117			}
118			for i, id := range req.ModifyDeadlineAckIds {
119				s.Deadlines[id] = req.ModifyDeadlineSeconds[i]
120			}
121			s.mu.Unlock()
122		}
123	}()
124	// Send responses.
125	for {
126		s.mu.Lock()
127		if len(s.pullResponses) == 0 {
128			s.mu.Unlock()
129			// Nothing to send, so wait for the client to shut down the stream.
130			err := <-errc // a real error, or at least EOF
131			if err == io.EOF {
132				return nil
133			}
134			return err
135		}
136		pr := s.pullResponses[0]
137		s.pullResponses = s.pullResponses[1:]
138		s.mu.Unlock()
139		if pr.err != nil {
140			// Add a slight delay to ensure the server receives any
141			// messages en route from the client before shutting down the stream.
142			// This reduces flakiness of tests involving retry.
143			time.Sleep(200 * time.Millisecond)
144		}
145		if pr.err == io.EOF {
146			return nil
147		}
148		if pr.err != nil {
149			return pr.err
150		}
151		// Return any error from Recv.
152		select {
153		case err := <-errc:
154			return err
155		default:
156		}
157		res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs}
158		if err := stream.Send(res); err != nil {
159			return err
160		}
161	}
162}
163
164func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) {
165	var err error
166	s.mu.Lock()
167	if len(s.ackErrs) > 0 {
168		err = s.ackErrs[0]
169		s.ackErrs = s.ackErrs[1:]
170	}
171	if err != nil {
172		s.mu.Unlock()
173		return nil, err
174	}
175	for _, id := range req.AckIds {
176		s.Acked[id] = true
177	}
178	s.mu.Unlock()
179	return &emptypb.Empty{}, nil
180}
181
182func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) {
183	var err error
184	s.mu.Lock()
185	if len(s.modAckErrs) > 0 {
186		err = s.modAckErrs[0]
187		s.modAckErrs = s.modAckErrs[1:]
188	}
189	if err != nil {
190		s.mu.Unlock()
191		return nil, err
192	}
193	for _, id := range req.AckIds {
194		s.Deadlines[id] = req.AckDeadlineSeconds
195	}
196	s.mu.Unlock()
197	return &emptypb.Empty{}, nil
198}
199
200func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
201	return s.sub, nil
202}
203