1// Copyright 2017 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package pubsub 16 17// This file provides a mock in-memory pubsub server for streaming pull testing. 18 19import ( 20 "context" 21 "io" 22 "sync" 23 "time" 24 25 "cloud.google.com/go/internal/testutil" 26 pb "google.golang.org/genproto/googleapis/pubsub/v1" 27 "google.golang.org/protobuf/types/known/emptypb" 28) 29 30type mockServer struct { 31 srv *testutil.Server 32 33 pb.SubscriberServer 34 35 Addr string 36 37 mu sync.Mutex 38 Acked map[string]bool // acked message IDs 39 Deadlines map[string]int32 // deadlines by message ID 40 pullResponses []*pullResponse 41 ackErrs []error 42 modAckErrs []error 43 wg sync.WaitGroup 44 sub *pb.Subscription 45} 46 47type pullResponse struct { 48 msgs []*pb.ReceivedMessage 49 err error 50} 51 52func newMockServer(port int) (*mockServer, error) { 53 srv, err := testutil.NewServerWithPort(port) 54 if err != nil { 55 return nil, err 56 } 57 mock := &mockServer{ 58 srv: srv, 59 Addr: srv.Addr, 60 Acked: map[string]bool{}, 61 Deadlines: map[string]int32{}, 62 sub: &pb.Subscription{ 63 AckDeadlineSeconds: 10, 64 PushConfig: &pb.PushConfig{}, 65 }, 66 } 67 pb.RegisterSubscriberServer(srv.Gsrv, mock) 68 srv.Start() 69 return mock, nil 70} 71 72// Each call to addStreamingPullMessages results in one StreamingPullResponse. 73func (s *mockServer) addStreamingPullMessages(msgs []*pb.ReceivedMessage) { 74 s.mu.Lock() 75 s.pullResponses = append(s.pullResponses, &pullResponse{msgs, nil}) 76 s.mu.Unlock() 77} 78 79func (s *mockServer) addStreamingPullError(err error) { 80 s.mu.Lock() 81 s.pullResponses = append(s.pullResponses, &pullResponse{nil, err}) 82 s.mu.Unlock() 83} 84 85func (s *mockServer) addAckResponse(err error) { 86 s.mu.Lock() 87 s.ackErrs = append(s.ackErrs, err) 88 s.mu.Unlock() 89} 90 91func (s *mockServer) addModAckResponse(err error) { 92 s.mu.Lock() 93 s.modAckErrs = append(s.modAckErrs, err) 94 s.mu.Unlock() 95} 96 97func (s *mockServer) wait() { 98 s.wg.Wait() 99} 100 101func (s *mockServer) StreamingPull(stream pb.Subscriber_StreamingPullServer) error { 102 s.wg.Add(1) 103 defer s.wg.Done() 104 errc := make(chan error, 1) 105 s.wg.Add(1) 106 go func() { 107 defer s.wg.Done() 108 for { 109 req, err := stream.Recv() 110 if err != nil { 111 errc <- err 112 return 113 } 114 s.mu.Lock() 115 for _, id := range req.AckIds { 116 s.Acked[id] = true 117 } 118 for i, id := range req.ModifyDeadlineAckIds { 119 s.Deadlines[id] = req.ModifyDeadlineSeconds[i] 120 } 121 s.mu.Unlock() 122 } 123 }() 124 // Send responses. 125 for { 126 s.mu.Lock() 127 if len(s.pullResponses) == 0 { 128 s.mu.Unlock() 129 // Nothing to send, so wait for the client to shut down the stream. 130 err := <-errc // a real error, or at least EOF 131 if err == io.EOF { 132 return nil 133 } 134 return err 135 } 136 pr := s.pullResponses[0] 137 s.pullResponses = s.pullResponses[1:] 138 s.mu.Unlock() 139 if pr.err != nil { 140 // Add a slight delay to ensure the server receives any 141 // messages en route from the client before shutting down the stream. 142 // This reduces flakiness of tests involving retry. 143 time.Sleep(200 * time.Millisecond) 144 } 145 if pr.err == io.EOF { 146 return nil 147 } 148 if pr.err != nil { 149 return pr.err 150 } 151 // Return any error from Recv. 152 select { 153 case err := <-errc: 154 return err 155 default: 156 } 157 res := &pb.StreamingPullResponse{ReceivedMessages: pr.msgs} 158 if err := stream.Send(res); err != nil { 159 return err 160 } 161 } 162} 163 164func (s *mockServer) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*emptypb.Empty, error) { 165 var err error 166 s.mu.Lock() 167 if len(s.ackErrs) > 0 { 168 err = s.ackErrs[0] 169 s.ackErrs = s.ackErrs[1:] 170 } 171 if err != nil { 172 s.mu.Unlock() 173 return nil, err 174 } 175 for _, id := range req.AckIds { 176 s.Acked[id] = true 177 } 178 s.mu.Unlock() 179 return &emptypb.Empty{}, nil 180} 181 182func (s *mockServer) ModifyAckDeadline(ctx context.Context, req *pb.ModifyAckDeadlineRequest) (*emptypb.Empty, error) { 183 var err error 184 s.mu.Lock() 185 if len(s.modAckErrs) > 0 { 186 err = s.modAckErrs[0] 187 s.modAckErrs = s.modAckErrs[1:] 188 } 189 if err != nil { 190 s.mu.Unlock() 191 return nil, err 192 } 193 for _, id := range req.AckIds { 194 s.Deadlines[id] = req.AckDeadlineSeconds 195 } 196 s.mu.Unlock() 197 return &emptypb.Empty{}, nil 198} 199 200func (s *mockServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) { 201 return s.sub, nil 202} 203