1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package test
15
16import (
17	"context"
18	"io"
19	"log"
20	"reflect"
21	"sync"
22
23	"cloud.google.com/go/internal/testutil"
24	"cloud.google.com/go/internal/uid"
25	"google.golang.org/api/option"
26	"google.golang.org/grpc"
27	"google.golang.org/grpc/codes"
28	"google.golang.org/grpc/status"
29
30	emptypb "github.com/golang/protobuf/ptypes/empty"
31	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
32)
33
34// MockServer is an in-memory mock implementation of a Pub/Sub Lite service,
35// which allows unit tests to inspect requests received by the server and send
36// fake responses.
37// This is the interface that should be used by tests.
38type MockServer interface {
39	// OnTestStart must be called at the start of each test to clear any existing
40	// state and set the test verifiers.
41	OnTestStart(*Verifiers)
42	// OnTestEnd should be called at the end of each test to flush the verifiers
43	// (i.e. check whether any expected requests were not sent to the server).
44	OnTestEnd()
45}
46
47// Server is a mock Pub/Sub Lite server that can be used for unit testing.
48type Server struct {
49	LiteServer MockServer
50	gRPCServer *testutil.Server
51}
52
53// NewServer creates a new mock Pub/Sub Lite server.
54func NewServer() (*Server, error) {
55	srv, err := testutil.NewServer()
56	if err != nil {
57		return nil, err
58	}
59	liteServer := newMockLiteServer()
60	pb.RegisterAdminServiceServer(srv.Gsrv, liteServer)
61	pb.RegisterPublisherServiceServer(srv.Gsrv, liteServer)
62	pb.RegisterSubscriberServiceServer(srv.Gsrv, liteServer)
63	pb.RegisterCursorServiceServer(srv.Gsrv, liteServer)
64	pb.RegisterPartitionAssignmentServiceServer(srv.Gsrv, liteServer)
65	srv.Start()
66	return &Server{LiteServer: liteServer, gRPCServer: srv}, nil
67}
68
69// ClientConn creates a client connection to the gRPC test server.
70func (s *Server) ClientConn() option.ClientOption {
71	conn, err := grpc.Dial(s.gRPCServer.Addr, grpc.WithInsecure())
72	if err != nil {
73		log.Fatal(err)
74	}
75	return option.WithGRPCConn(conn)
76}
77
78// Close shuts down the server and releases all resources.
79func (s *Server) Close() {
80	s.gRPCServer.Close()
81}
82
83// mockLiteServer implements the MockServer interface.
84type mockLiteServer struct {
85	pb.AdminServiceServer
86	pb.PublisherServiceServer
87	pb.SubscriberServiceServer
88	pb.CursorServiceServer
89	pb.PartitionAssignmentServiceServer
90
91	mu sync.Mutex
92
93	testVerifiers *Verifiers
94	testIDs       *uid.Space
95	currentTestID string
96}
97
98func newMockLiteServer() *mockLiteServer {
99	return &mockLiteServer{
100		testIDs: uid.NewSpace("mockLiteServer", nil),
101	}
102}
103
104func (s *mockLiteServer) popGlobalVerifiers(request interface{}) (interface{}, error) {
105	s.mu.Lock()
106	defer s.mu.Unlock()
107
108	if s.testVerifiers == nil {
109		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
110	}
111	return s.testVerifiers.GlobalVerifier.Pop(request)
112}
113
114func (s *mockLiteServer) popStreamVerifier(key string) (*RPCVerifier, error) {
115	s.mu.Lock()
116	defer s.mu.Unlock()
117
118	if s.testVerifiers == nil {
119		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
120	}
121	return s.testVerifiers.streamVerifiers.Pop(key)
122}
123
124func (s *mockLiteServer) handleStream(stream grpc.ServerStream, req interface{}, requestType reflect.Type, key string) (err error) {
125	testID := s.currentTest()
126	if testID == "" {
127		return status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
128	}
129	verifier, err := s.popStreamVerifier(key)
130	if err != nil {
131		return err
132	}
133
134	// Verify initial request.
135	retResponse, retErr := verifier.Pop(req)
136	var ok bool
137
138	for {
139		// See comments for RPCVerifier.Push for valid stream request/response
140		// combinations.
141		if retErr != nil {
142			err = retErr
143			break
144		}
145		if retResponse != nil {
146			if err = stream.SendMsg(retResponse); err != nil {
147				err = status.Errorf(codes.FailedPrecondition, "mockserver: stream send error: %v", err)
148				break
149			}
150		}
151
152		// Check whether the next response isn't blocked on a request.
153		ok, retResponse, retErr = verifier.TryPop()
154		if ok {
155			continue
156		}
157
158		req = reflect.New(requestType).Interface()
159		if err = stream.RecvMsg(req); err == io.EOF {
160			break
161		} else if err != nil {
162			err = status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error: %v", err)
163			break
164		}
165		if testID != s.currentTest() {
166			err = status.Errorf(codes.FailedPrecondition, "mockserver: previous test has ended")
167			break
168		}
169		retResponse, retErr = verifier.Pop(req)
170	}
171
172	// Check whether the stream ended prematurely.
173	if testID == s.currentTest() {
174		verifier.Flush()
175	}
176	return
177}
178
179// MockServer implementation.
180
181func (s *mockLiteServer) OnTestStart(verifiers *Verifiers) {
182	s.mu.Lock()
183	defer s.mu.Unlock()
184
185	if s.currentTestID != "" {
186		panic("mockserver is already in use by another test")
187	}
188	s.currentTestID = s.testIDs.New()
189	s.testVerifiers = verifiers
190}
191
192func (s *mockLiteServer) OnTestEnd() {
193	s.mu.Lock()
194	defer s.mu.Unlock()
195
196	s.currentTestID = ""
197	if s.testVerifiers != nil {
198		s.testVerifiers.flush()
199	}
200}
201
202func (s *mockLiteServer) currentTest() string {
203	s.mu.Lock()
204	defer s.mu.Unlock()
205	return s.currentTestID
206}
207
208// PublisherService implementation.
209
210func (s *mockLiteServer) Publish(stream pb.PublisherService_PublishServer) error {
211	req, err := stream.Recv()
212	if err != nil {
213		return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err)
214	}
215	if len(req.GetInitialRequest().GetTopic()) == 0 {
216		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial publish request: %v", req)
217	}
218
219	initReq := req.GetInitialRequest()
220	k := keyPartition(publishStreamType, initReq.GetTopic(), int(initReq.GetPartition()))
221	return s.handleStream(stream, req, reflect.TypeOf(pb.PublishRequest{}), k)
222}
223
224// SubscriberService implementation.
225
226func (s *mockLiteServer) Subscribe(stream pb.SubscriberService_SubscribeServer) error {
227	req, err := stream.Recv()
228	if err != nil {
229		return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err)
230	}
231	if len(req.GetInitial().GetSubscription()) == 0 {
232		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial subscribe request: %v", req)
233	}
234
235	initReq := req.GetInitial()
236	k := keyPartition(subscribeStreamType, initReq.GetSubscription(), int(initReq.GetPartition()))
237	return s.handleStream(stream, req, reflect.TypeOf(pb.SubscribeRequest{}), k)
238}
239
240// CursorService implementation.
241
242func (s *mockLiteServer) StreamingCommitCursor(stream pb.CursorService_StreamingCommitCursorServer) error {
243	req, err := stream.Recv()
244	if err != nil {
245		return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err)
246	}
247	if len(req.GetInitial().GetSubscription()) == 0 {
248		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial streaming commit cursor request: %v", req)
249	}
250
251	initReq := req.GetInitial()
252	k := keyPartition(commitStreamType, initReq.GetSubscription(), int(initReq.GetPartition()))
253	return s.handleStream(stream, req, reflect.TypeOf(pb.StreamingCommitCursorRequest{}), k)
254}
255
256// PartitionAssignmentService implementation.
257
258func (s *mockLiteServer) AssignPartitions(stream pb.PartitionAssignmentService_AssignPartitionsServer) error {
259	req, err := stream.Recv()
260	if err != nil {
261		return status.Errorf(codes.FailedPrecondition, "mockserver: stream recv error before initial request: %v", err)
262	}
263	if len(req.GetInitial().GetSubscription()) == 0 {
264		return status.Errorf(codes.InvalidArgument, "mockserver: received invalid initial partition assignment request: %v", req)
265	}
266
267	k := key(assignmentStreamType, req.GetInitial().GetSubscription())
268	return s.handleStream(stream, req, reflect.TypeOf(pb.PartitionAssignmentRequest{}), k)
269}
270
271// AdminService implementation.
272
273func (s *mockLiteServer) doTopicResponse(ctx context.Context, req interface{}) (*pb.Topic, error) {
274	retResponse, retErr := s.popGlobalVerifiers(req)
275	if retErr != nil {
276		return nil, retErr
277	}
278	resp, ok := retResponse.(*pb.Topic)
279	if !ok {
280		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse))
281	}
282	return resp, nil
283}
284
285func (s *mockLiteServer) doSubscriptionResponse(ctx context.Context, req interface{}) (*pb.Subscription, error) {
286	retResponse, retErr := s.popGlobalVerifiers(req)
287	if retErr != nil {
288		return nil, retErr
289	}
290	resp, ok := retResponse.(*pb.Subscription)
291	if !ok {
292		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse))
293	}
294	return resp, nil
295}
296
297func (s *mockLiteServer) doEmptyResponse(ctx context.Context, req interface{}) (*emptypb.Empty, error) {
298	retResponse, retErr := s.popGlobalVerifiers(req)
299	if retErr != nil {
300		return nil, retErr
301	}
302	resp, ok := retResponse.(*emptypb.Empty)
303	if !ok {
304		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse))
305	}
306	return resp, nil
307}
308
309func (s *mockLiteServer) CreateTopic(ctx context.Context, req *pb.CreateTopicRequest) (*pb.Topic, error) {
310	return s.doTopicResponse(ctx, req)
311}
312
313func (s *mockLiteServer) UpdateTopic(ctx context.Context, req *pb.UpdateTopicRequest) (*pb.Topic, error) {
314	return s.doTopicResponse(ctx, req)
315}
316
317func (s *mockLiteServer) GetTopic(ctx context.Context, req *pb.GetTopicRequest) (*pb.Topic, error) {
318	return s.doTopicResponse(ctx, req)
319}
320
321func (s *mockLiteServer) GetTopicPartitions(ctx context.Context, req *pb.GetTopicPartitionsRequest) (*pb.TopicPartitions, error) {
322	retResponse, retErr := s.popGlobalVerifiers(req)
323	if retErr != nil {
324		return nil, retErr
325	}
326	resp, ok := retResponse.(*pb.TopicPartitions)
327	if !ok {
328		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse))
329	}
330	return resp, nil
331}
332
333func (s *mockLiteServer) DeleteTopic(ctx context.Context, req *pb.DeleteTopicRequest) (*emptypb.Empty, error) {
334	return s.doEmptyResponse(ctx, req)
335}
336
337func (s *mockLiteServer) CreateSubscription(ctx context.Context, req *pb.CreateSubscriptionRequest) (*pb.Subscription, error) {
338	return s.doSubscriptionResponse(ctx, req)
339}
340
341func (s *mockLiteServer) GetSubscription(ctx context.Context, req *pb.GetSubscriptionRequest) (*pb.Subscription, error) {
342	return s.doSubscriptionResponse(ctx, req)
343}
344
345func (s *mockLiteServer) UpdateSubscription(ctx context.Context, req *pb.UpdateSubscriptionRequest) (*pb.Subscription, error) {
346	return s.doSubscriptionResponse(ctx, req)
347}
348
349func (s *mockLiteServer) DeleteSubscription(ctx context.Context, req *pb.DeleteSubscriptionRequest) (*emptypb.Empty, error) {
350	return s.doEmptyResponse(ctx, req)
351}
352
353func (s *mockLiteServer) ListTopics(ctx context.Context, req *pb.ListTopicsRequest) (*pb.ListTopicsResponse, error) {
354	retResponse, retErr := s.popGlobalVerifiers(req)
355	if retErr != nil {
356		return nil, retErr
357	}
358	resp, ok := retResponse.(*pb.ListTopicsResponse)
359	if !ok {
360		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse))
361	}
362	return resp, nil
363}
364
365func (s *mockLiteServer) ListTopicSubscriptions(ctx context.Context, req *pb.ListTopicSubscriptionsRequest) (*pb.ListTopicSubscriptionsResponse, error) {
366	retResponse, retErr := s.popGlobalVerifiers(req)
367	if retErr != nil {
368		return nil, retErr
369	}
370	resp, ok := retResponse.(*pb.ListTopicSubscriptionsResponse)
371	if !ok {
372		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse))
373	}
374	return resp, nil
375}
376
377func (s *mockLiteServer) ListSubscriptions(ctx context.Context, req *pb.ListSubscriptionsRequest) (*pb.ListSubscriptionsResponse, error) {
378	retResponse, retErr := s.popGlobalVerifiers(req)
379	if retErr != nil {
380		return nil, retErr
381	}
382	resp, ok := retResponse.(*pb.ListSubscriptionsResponse)
383	if !ok {
384		return nil, status.Errorf(codes.FailedPrecondition, "mockserver: invalid response type %v", reflect.TypeOf(retResponse))
385	}
386	return resp, nil
387}
388