1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package test
15
16import (
17	"context"
18	"fmt"
19	"io"
20	"reflect"
21	"sync"
22
23	"cloud.google.com/go/internal/testutil"
24	"google.golang.org/grpc"
25	"google.golang.org/grpc/codes"
26	"google.golang.org/grpc/status"
27
28	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
29)
30
31// Server is a mock Pub/Sub Lite server that can be used for unit testing.
32type Server struct {
33	LiteServer MockServer
34	gRPCServer *testutil.Server
35}
36
37// MockServer is an in-memory mock implementation of a Pub/Sub Lite service,
38// which allows unit tests to inspect requests received by the server and send
39// fake responses.
40// This is the interface that should be used by tests.
41type MockServer interface {
42	// OnTestStart must be called at the start of each test to clear any existing
43	// state and set the verifier for unary RPCs.
44	OnTestStart(globalVerifier *RPCVerifier)
45	// OnTestEnd should be called at the end of each test to flush the verifiers
46	// (i.e. check whether any expected requests were not sent to the server).
47	OnTestEnd()
48	// AddPublishStream adds a verifier for a publish stream of a topic partition.
49	AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier)
50	// AddSubscribeStream adds a verifier for a subscribe stream of a partition.
51	AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier)
52	// AddCommitStream adds a verifier for a commit stream of a partition.
53	AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier)
54	// AddAssignmentStream adds a verifier for a partition assignment stream for a
55	// subscription.
56	AddAssignmentStream(subscription string, streamVerifier *RPCVerifier)
57}
58
59// NewServer creates a new mock Pub/Sub Lite server.
60func NewServer() (*Server, error) {
61	srv, err := testutil.NewServer()
62	if err != nil {
63		return nil, err
64	}
65	liteServer := newMockLiteServer()
66	pb.RegisterAdminServiceServer(srv.Gsrv, liteServer)
67	pb.RegisterPublisherServiceServer(srv.Gsrv, liteServer)
68	pb.RegisterSubscriberServiceServer(srv.Gsrv, liteServer)
69	pb.RegisterCursorServiceServer(srv.Gsrv, liteServer)
70	pb.RegisterPartitionAssignmentServiceServer(srv.Gsrv, liteServer)
71	srv.Start()
72	return &Server{LiteServer: liteServer, gRPCServer: srv}, nil
73}
74
75// Addr returns the address that the server is listening on.
76func (s *Server) Addr() string {
77	return s.gRPCServer.Addr
78}
79
80// Close shuts down the server and releases all resources.
81func (s *Server) Close() {
82	s.gRPCServer.Close()
83}
84
85type streamHolder struct {
86	stream   grpc.ServerStream
87	verifier *RPCVerifier
88}
89
90// mockLiteServer implements the MockServer interface.
91type mockLiteServer struct {
92	pb.AdminServiceServer
93	pb.PublisherServiceServer
94	pb.SubscriberServiceServer
95	pb.CursorServiceServer
96	pb.PartitionAssignmentServiceServer
97
98	mu sync.Mutex
99
100	// Global list of verifiers for all unary RPCs. This should be set before the
101	// test begins.
102	globalVerifier *RPCVerifier
103
104	// Stream verifiers by key.
105	publishVerifiers    *keyedStreamVerifiers
106	subscribeVerifiers  *keyedStreamVerifiers
107	commitVerifiers     *keyedStreamVerifiers
108	assignmentVerifiers *keyedStreamVerifiers
109
110	nextStreamID  int
111	activeStreams map[int]*streamHolder
112	testActive    bool
113}
114
// key builds the lookup key for a partition-scoped stream verifier by joining
// the resource path and the partition number with a colon, e.g. "path:3".
func key(path string, partition int) string {
	const format = "%s:%d"
	return fmt.Sprintf(format, path, partition)
}
118
119func newMockLiteServer() *mockLiteServer {
120	return &mockLiteServer{
121		publishVerifiers:    newKeyedStreamVerifiers(),
122		subscribeVerifiers:  newKeyedStreamVerifiers(),
123		commitVerifiers:     newKeyedStreamVerifiers(),
124		assignmentVerifiers: newKeyedStreamVerifiers(),
125		activeStreams:       make(map[int]*streamHolder),
126	}
127}
128
129func (s *mockLiteServer) startStream(stream grpc.ServerStream, verifier *RPCVerifier) (id int) {
130	s.mu.Lock()
131	defer s.mu.Unlock()
132
133	id = s.nextStreamID
134	s.nextStreamID++
135	s.activeStreams[id] = &streamHolder{stream: stream, verifier: verifier}
136	return
137}
138
139func (s *mockLiteServer) endStream(id int) {
140	s.mu.Lock()
141	defer s.mu.Unlock()
142
143	delete(s.activeStreams, id)
144}
145
146func (s *mockLiteServer) popStreamVerifier(key string, keyedVerifiers *keyedStreamVerifiers) (*RPCVerifier, error) {
147	s.mu.Lock()
148	defer s.mu.Unlock()
149
150	return keyedVerifiers.Pop(key)
151}
152
153func (s *mockLiteServer) handleStream(stream grpc.ServerStream, req interface{}, requestType reflect.Type, key string, keyedVerifiers *keyedStreamVerifiers) (err error) {
154	verifier, err := s.popStreamVerifier(key, keyedVerifiers)
155	if err != nil {
156		return err
157	}
158
159	id := s.startStream(stream, verifier)
160
161	// Verify initial request.
162	retResponse, retErr := verifier.Pop(req)
163	var ok bool
164
165	for {
166		if retErr != nil {
167			err = retErr
168			break
169		}
170		if err = stream.SendMsg(retResponse); err != nil {
171			err = status.Errorf(codes.FailedPrecondition, "mockserver: stream send error: %v", err)
172			break
173		}
174
175		// Check whether the next response isn't blocked on a request.
176		ok, retResponse, retErr = verifier.TryPop()
177		if ok {
178			continue
179		}
180
181		req = reflect.New(requestType).Interface()
182		if err = stream.RecvMsg(req); err == io.EOF {
183			break
184		} else if err != nil {
185			err = status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error: %v", err)
186			break
187		}
188		retResponse, retErr = verifier.Pop(req)
189	}
190
191	// Check whether the stream ended prematurely.
192	verifier.Flush()
193	s.endStream(id)
194	return
195}
196
197// MockServer implementation.
198
199func (s *mockLiteServer) OnTestStart(globalVerifier *RPCVerifier) {
200	s.mu.Lock()
201	defer s.mu.Unlock()
202
203	if s.testActive {
204		panic("mockserver is already in use by another test")
205	}
206
207	s.testActive = true
208	s.globalVerifier = globalVerifier
209	s.publishVerifiers.Reset()
210	s.subscribeVerifiers.Reset()
211	s.commitVerifiers.Reset()
212	s.assignmentVerifiers.Reset()
213	s.activeStreams = make(map[int]*streamHolder)
214}
215
216func (s *mockLiteServer) OnTestEnd() {
217	s.mu.Lock()
218	defer s.mu.Unlock()
219
220	s.testActive = false
221	if s.globalVerifier != nil {
222		s.globalVerifier.Flush()
223	}
224
225	for _, as := range s.activeStreams {
226		as.verifier.Flush()
227	}
228}
229
230func (s *mockLiteServer) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) {
231	s.mu.Lock()
232	defer s.mu.Unlock()
233	s.publishVerifiers.Push(key(topic, partition), streamVerifier)
234}
235
236func (s *mockLiteServer) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) {
237	s.mu.Lock()
238	defer s.mu.Unlock()
239	s.subscribeVerifiers.Push(key(subscription, partition), streamVerifier)
240}
241
242func (s *mockLiteServer) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) {
243	s.mu.Lock()
244	defer s.mu.Unlock()
245	s.commitVerifiers.Push(key(subscription, partition), streamVerifier)
246}
247
248func (s *mockLiteServer) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) {
249	s.mu.Lock()
250	defer s.mu.Unlock()
251	s.assignmentVerifiers.Push(subscription, streamVerifier)
252}
253
254// PublisherService implementation.
255
256func (s *mockLiteServer) Publish(stream pb.PublisherService_PublishServer) error {
257	req, err := stream.Recv()
258	if err != nil {
259		return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err)
260	}
261	if len(req.GetInitialRequest().GetTopic()) == 0 {
262		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial publish request: %v", req)
263	}
264
265	initReq := req.GetInitialRequest()
266	k := key(initReq.GetTopic(), int(initReq.GetPartition()))
267	return s.handleStream(stream, req, reflect.TypeOf(pb.PublishRequest{}), k, s.publishVerifiers)
268}
269
270// SubscriberService implementation.
271
272func (s *mockLiteServer) Subscribe(stream pb.SubscriberService_SubscribeServer) error {
273	req, err := stream.Recv()
274	if err != nil {
275		return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err)
276	}
277	if len(req.GetInitial().GetSubscription()) == 0 {
278		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial subscribe request: %v", req)
279	}
280
281	initReq := req.GetInitial()
282	k := key(initReq.GetSubscription(), int(initReq.GetPartition()))
283	return s.handleStream(stream, req, reflect.TypeOf(pb.SubscribeRequest{}), k, s.subscribeVerifiers)
284}
285
286// CursorService implementation.
287
288func (s *mockLiteServer) StreamingCommitCursor(stream pb.CursorService_StreamingCommitCursorServer) error {
289	req, err := stream.Recv()
290	if err != nil {
291		return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err)
292	}
293	if len(req.GetInitial().GetSubscription()) == 0 {
294		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial streaming commit cursor request: %v", req)
295	}
296
297	initReq := req.GetInitial()
298	k := key(initReq.GetSubscription(), int(initReq.GetPartition()))
299	return s.handleStream(stream, req, reflect.TypeOf(pb.StreamingCommitCursorRequest{}), k, s.commitVerifiers)
300}
301
302// PartitionAssignmentService implementation.
303
304func (s *mockLiteServer) AssignPartitions(stream pb.PartitionAssignmentService_AssignPartitionsServer) error {
305	req, err := stream.Recv()
306	if err != nil {
307		return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err)
308	}
309	if len(req.GetInitial().GetSubscription()) == 0 {
310		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial partition assignment request: %v", req)
311	}
312
313	k := req.GetInitial().GetSubscription()
314	return s.handleStream(stream, req, reflect.TypeOf(pb.PartitionAssignmentRequest{}), k, s.assignmentVerifiers)
315}
316
317// AdminService implementation.
318
319func (s *mockLiteServer) GetTopicPartitions(ctx context.Context, req *pb.GetTopicPartitionsRequest) (*pb.TopicPartitions, error) {
320	s.mu.Lock()
321	defer s.mu.Unlock()
322
323	retResponse, retErr := s.globalVerifier.Pop(req)
324	if retErr != nil {
325		return nil, retErr
326	}
327	resp, ok := retResponse.(*pb.TopicPartitions)
328	if !ok {
329		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse))
330	}
331	return resp, nil
332}
333