1// Copyright 2020 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13 14package test 15 16import ( 17 "context" 18 "fmt" 19 "io" 20 "reflect" 21 "sync" 22 23 "cloud.google.com/go/internal/testutil" 24 "google.golang.org/grpc" 25 "google.golang.org/grpc/codes" 26 "google.golang.org/grpc/status" 27 28 pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1" 29) 30 31// Server is a mock Pub/Sub Lite server that can be used for unit testing. 32type Server struct { 33 LiteServer MockServer 34 gRPCServer *testutil.Server 35} 36 37// MockServer is an in-memory mock implementation of a Pub/Sub Lite service, 38// which allows unit tests to inspect requests received by the server and send 39// fake responses. 40// This is the interface that should be used by tests. 41type MockServer interface { 42 // OnTestStart must be called at the start of each test to clear any existing 43 // state and set the verifier for unary RPCs. 44 OnTestStart(globalVerifier *RPCVerifier) 45 // OnTestEnd should be called at the end of each test to flush the verifiers 46 // (i.e. check whether any expected requests were not sent to the server). 47 OnTestEnd() 48 // AddPublishStream adds a verifier for a publish stream of a topic partition. 49 AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) 50 // AddSubscribeStream adds a verifier for a subscribe stream of a partition. 51 AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) 52 // AddCommitStream adds a verifier for a commit stream of a partition. 53 AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) 54 // AddAssignmentStream adds a verifier for a partition assignment stream for a 55 // subscription. 56 AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) 57} 58 59// NewServer creates a new mock Pub/Sub Lite server. 60func NewServer() (*Server, error) { 61 srv, err := testutil.NewServer() 62 if err != nil { 63 return nil, err 64 } 65 liteServer := newMockLiteServer() 66 pb.RegisterAdminServiceServer(srv.Gsrv, liteServer) 67 pb.RegisterPublisherServiceServer(srv.Gsrv, liteServer) 68 pb.RegisterSubscriberServiceServer(srv.Gsrv, liteServer) 69 pb.RegisterCursorServiceServer(srv.Gsrv, liteServer) 70 pb.RegisterPartitionAssignmentServiceServer(srv.Gsrv, liteServer) 71 srv.Start() 72 return &Server{LiteServer: liteServer, gRPCServer: srv}, nil 73} 74 75// Addr returns the address that the server is listening on. 76func (s *Server) Addr() string { 77 return s.gRPCServer.Addr 78} 79 80// Close shuts down the server and releases all resources. 81func (s *Server) Close() { 82 s.gRPCServer.Close() 83} 84 85type streamHolder struct { 86 stream grpc.ServerStream 87 verifier *RPCVerifier 88} 89 90// mockLiteServer implements the MockServer interface. 91type mockLiteServer struct { 92 pb.AdminServiceServer 93 pb.PublisherServiceServer 94 pb.SubscriberServiceServer 95 pb.CursorServiceServer 96 pb.PartitionAssignmentServiceServer 97 98 mu sync.Mutex 99 100 // Global list of verifiers for all unary RPCs. This should be set before the 101 // test begins. 102 globalVerifier *RPCVerifier 103 104 // Stream verifiers by key. 105 publishVerifiers *keyedStreamVerifiers 106 subscribeVerifiers *keyedStreamVerifiers 107 commitVerifiers *keyedStreamVerifiers 108 assignmentVerifiers *keyedStreamVerifiers 109 110 nextStreamID int 111 activeStreams map[int]*streamHolder 112 testActive bool 113} 114 115func key(path string, partition int) string { 116 return fmt.Sprintf("%s:%d", path, partition) 117} 118 119func newMockLiteServer() *mockLiteServer { 120 return &mockLiteServer{ 121 publishVerifiers: newKeyedStreamVerifiers(), 122 subscribeVerifiers: newKeyedStreamVerifiers(), 123 commitVerifiers: newKeyedStreamVerifiers(), 124 assignmentVerifiers: newKeyedStreamVerifiers(), 125 activeStreams: make(map[int]*streamHolder), 126 } 127} 128 129func (s *mockLiteServer) startStream(stream grpc.ServerStream, verifier *RPCVerifier) (id int) { 130 s.mu.Lock() 131 defer s.mu.Unlock() 132 133 id = s.nextStreamID 134 s.nextStreamID++ 135 s.activeStreams[id] = &streamHolder{stream: stream, verifier: verifier} 136 return 137} 138 139func (s *mockLiteServer) endStream(id int) { 140 s.mu.Lock() 141 defer s.mu.Unlock() 142 143 delete(s.activeStreams, id) 144} 145 146func (s *mockLiteServer) popStreamVerifier(key string, keyedVerifiers *keyedStreamVerifiers) (*RPCVerifier, error) { 147 s.mu.Lock() 148 defer s.mu.Unlock() 149 150 return keyedVerifiers.Pop(key) 151} 152 153func (s *mockLiteServer) handleStream(stream grpc.ServerStream, req interface{}, requestType reflect.Type, key string, keyedVerifiers *keyedStreamVerifiers) (err error) { 154 verifier, err := s.popStreamVerifier(key, keyedVerifiers) 155 if err != nil { 156 return err 157 } 158 159 id := s.startStream(stream, verifier) 160 161 // Verify initial request. 162 retResponse, retErr := verifier.Pop(req) 163 var ok bool 164 165 for { 166 if retErr != nil { 167 err = retErr 168 break 169 } 170 if err = stream.SendMsg(retResponse); err != nil { 171 err = status.Errorf(codes.FailedPrecondition, "mockserver: stream send error: %v", err) 172 break 173 } 174 175 // Check whether the next response isn't blocked on a request. 176 ok, retResponse, retErr = verifier.TryPop() 177 if ok { 178 continue 179 } 180 181 req = reflect.New(requestType).Interface() 182 if err = stream.RecvMsg(req); err == io.EOF { 183 break 184 } else if err != nil { 185 err = status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error: %v", err) 186 break 187 } 188 retResponse, retErr = verifier.Pop(req) 189 } 190 191 // Check whether the stream ended prematurely. 192 verifier.Flush() 193 s.endStream(id) 194 return 195} 196 197// MockServer implementation. 198 199func (s *mockLiteServer) OnTestStart(globalVerifier *RPCVerifier) { 200 s.mu.Lock() 201 defer s.mu.Unlock() 202 203 if s.testActive { 204 panic("mockserver is already in use by another test") 205 } 206 207 s.testActive = true 208 s.globalVerifier = globalVerifier 209 s.publishVerifiers.Reset() 210 s.subscribeVerifiers.Reset() 211 s.commitVerifiers.Reset() 212 s.assignmentVerifiers.Reset() 213 s.activeStreams = make(map[int]*streamHolder) 214} 215 216func (s *mockLiteServer) OnTestEnd() { 217 s.mu.Lock() 218 defer s.mu.Unlock() 219 220 s.testActive = false 221 if s.globalVerifier != nil { 222 s.globalVerifier.Flush() 223 } 224 225 for _, as := range s.activeStreams { 226 as.verifier.Flush() 227 } 228} 229 230func (s *mockLiteServer) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) { 231 s.mu.Lock() 232 defer s.mu.Unlock() 233 s.publishVerifiers.Push(key(topic, partition), streamVerifier) 234} 235 236func (s *mockLiteServer) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) { 237 s.mu.Lock() 238 defer s.mu.Unlock() 239 s.subscribeVerifiers.Push(key(subscription, partition), streamVerifier) 240} 241 242func (s *mockLiteServer) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) { 243 s.mu.Lock() 244 defer s.mu.Unlock() 245 s.commitVerifiers.Push(key(subscription, partition), streamVerifier) 246} 247 248func (s *mockLiteServer) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) { 249 s.mu.Lock() 250 defer s.mu.Unlock() 251 s.assignmentVerifiers.Push(subscription, streamVerifier) 252} 253 254// PublisherService implementation. 255 256func (s *mockLiteServer) Publish(stream pb.PublisherService_PublishServer) error { 257 req, err := stream.Recv() 258 if err != nil { 259 return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err) 260 } 261 if len(req.GetInitialRequest().GetTopic()) == 0 { 262 return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial publish request: %v", req) 263 } 264 265 initReq := req.GetInitialRequest() 266 k := key(initReq.GetTopic(), int(initReq.GetPartition())) 267 return s.handleStream(stream, req, reflect.TypeOf(pb.PublishRequest{}), k, s.publishVerifiers) 268} 269 270// SubscriberService implementation. 271 272func (s *mockLiteServer) Subscribe(stream pb.SubscriberService_SubscribeServer) error { 273 req, err := stream.Recv() 274 if err != nil { 275 return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err) 276 } 277 if len(req.GetInitial().GetSubscription()) == 0 { 278 return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial subscribe request: %v", req) 279 } 280 281 initReq := req.GetInitial() 282 k := key(initReq.GetSubscription(), int(initReq.GetPartition())) 283 return s.handleStream(stream, req, reflect.TypeOf(pb.SubscribeRequest{}), k, s.subscribeVerifiers) 284} 285 286// CursorService implementation. 287 288func (s *mockLiteServer) StreamingCommitCursor(stream pb.CursorService_StreamingCommitCursorServer) error { 289 req, err := stream.Recv() 290 if err != nil { 291 return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err) 292 } 293 if len(req.GetInitial().GetSubscription()) == 0 { 294 return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial streaming commit cursor request: %v", req) 295 } 296 297 initReq := req.GetInitial() 298 k := key(initReq.GetSubscription(), int(initReq.GetPartition())) 299 return s.handleStream(stream, req, reflect.TypeOf(pb.StreamingCommitCursorRequest{}), k, s.commitVerifiers) 300} 301 302// PartitionAssignmentService implementation. 303 304func (s *mockLiteServer) AssignPartitions(stream pb.PartitionAssignmentService_AssignPartitionsServer) error { 305 req, err := stream.Recv() 306 if err != nil { 307 return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err) 308 } 309 if len(req.GetInitial().GetSubscription()) == 0 { 310 return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial partition assignment request: %v", req) 311 } 312 313 k := req.GetInitial().GetSubscription() 314 return s.handleStream(stream, req, reflect.TypeOf(pb.PartitionAssignmentRequest{}), k, s.assignmentVerifiers) 315} 316 317// AdminService implementation. 318 319func (s *mockLiteServer) GetTopicPartitions(ctx context.Context, req *pb.GetTopicPartitionsRequest) (*pb.TopicPartitions, error) { 320 s.mu.Lock() 321 defer s.mu.Unlock() 322 323 retResponse, retErr := s.globalVerifier.Pop(req) 324 if retErr != nil { 325 return nil, retErr 326 } 327 resp, ok := retResponse.(*pb.TopicPartitions) 328 if !ok { 329 return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse)) 330 } 331 return resp, nil 332} 333