1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package test
15
16import (
17	"container/list"
18	"fmt"
19	"sync"
20	"testing"
21	"time"
22
23	"cloud.google.com/go/internal/testutil"
24	"google.golang.org/grpc/codes"
25	"google.golang.org/grpc/status"
26)
27
// blockWaitTimeout caps how long any Barrier wait may block, so that a test
// that never calls Release (or a server that never reaches the barrier) fails
// instead of deadlocking.
const blockWaitTimeout = 30 * time.Second
33
// Barrier is used to perform two-way synchronization between the server and
// client (test) to ensure tests are deterministic.
//
// The server closes serverBlock once it is ready to send a response, then
// waits for the test to close clientBlock (by calling Release) before sending.
type Barrier struct {
	// Closed by the server when it reaches the barrier and is ready to send
	// the response; Release waits on it.
	serverBlock chan struct{}
	// Closed by Release when the test wants the server to send the response;
	// the server waits on it.
	clientBlock chan struct{}
	// Set by Release if the server did not reach the barrier in time.
	err error
}
43
44func newBarrier() *Barrier {
45	return &Barrier{
46		serverBlock: make(chan struct{}),
47		clientBlock: make(chan struct{}),
48	}
49}
50
51// Release should be called by the test.
52func (b *Barrier) Release() {
53	// Wait for the server to reach the barrier.
54	select {
55	case <-time.After(blockWaitTimeout):
56		// Note: avoid returning a retryable code to quickly terminate the test.
57		b.err = status.Errorf(codes.FailedPrecondition, "mockserver: server did not reach barrier within %v", blockWaitTimeout)
58	case <-b.serverBlock:
59	}
60
61	// Then close the client block.
62	close(b.clientBlock)
63}
64
65func (b *Barrier) serverWait() error {
66	if b.err != nil {
67		return b.err
68	}
69
70	// Close the server block to signal the server reaching the point where it is
71	// ready to send the response.
72	close(b.serverBlock)
73
74	// Wait for the test to release the client block.
75	select {
76	case <-time.After(blockWaitTimeout):
77		// Note: avoid returning a retryable code to quickly terminate the test.
78		return status.Errorf(codes.FailedPrecondition, "mockserver: test did not unblock response within %v", blockWaitTimeout)
79	case <-b.clientBlock:
80		return nil
81	}
82}
83
84type rpcMetadata struct {
85	wantRequest interface{}
86	retResponse interface{}
87	retErr      error
88	barrier     *Barrier
89}
90
91// wait until the barrier is released by the test, or a timeout occurs.
92// Returns immediately if there was no block.
93func (r *rpcMetadata) wait() error {
94	if r.barrier == nil {
95		return nil
96	}
97	return r.barrier.serverWait()
98}
99
// RPCVerifier stores a queue of requests expected from the client, and the
// corresponding response or error to return.
type RPCVerifier struct {
	t  *testing.T
	mu sync.Mutex
	// Queue of expected tuples, oldest first; Value = *rpcMetadata.
	rpcs *list.List
	// Index of the most recent call, used in error messages. NewRPCVerifier
	// initializes it to -1 so that the first call is numbered 0.
	numCalls int
}
108
109// NewRPCVerifier creates a new verifier for requests received by the server.
110func NewRPCVerifier(t *testing.T) *RPCVerifier {
111	return &RPCVerifier{
112		t:        t,
113		rpcs:     list.New(),
114		numCalls: -1,
115	}
116}
117
118// Push appends a new {request, response, error} tuple.
119//
120// Valid combinations for unary and streaming RPCs:
121// - {request, response, nil}
122// - {request, nil, error}
123//
124// Additional combinations for streams only:
125// - {nil, response, nil}: send a response without a request (e.g. messages).
126// - {nil, nil, error}: break the stream without a request.
127// - {request, nil, nil}: expect a request, but don't send any response.
128func (v *RPCVerifier) Push(wantRequest interface{}, retResponse interface{}, retErr error) {
129	v.mu.Lock()
130	defer v.mu.Unlock()
131
132	v.rpcs.PushBack(&rpcMetadata{
133		wantRequest: wantRequest,
134		retResponse: retResponse,
135		retErr:      retErr,
136	})
137}
138
139// PushWithBarrier is like Push, but returns a barrier that the test should call
140// Release when it would like the response to be sent to the client. This is
141// useful for synchronizing with work that needs to be done on the client.
142func (v *RPCVerifier) PushWithBarrier(wantRequest interface{}, retResponse interface{}, retErr error) *Barrier {
143	v.mu.Lock()
144	defer v.mu.Unlock()
145
146	barrier := newBarrier()
147	v.rpcs.PushBack(&rpcMetadata{
148		wantRequest: wantRequest,
149		retResponse: retResponse,
150		retErr:      retErr,
151		barrier:     barrier,
152	})
153	return barrier
154}
155
156// Pop validates the received request with the next {request, response, error}
157// tuple.
158func (v *RPCVerifier) Pop(gotRequest interface{}) (interface{}, error) {
159	v.mu.Lock()
160	defer v.mu.Unlock()
161
162	v.numCalls++
163	elem := v.rpcs.Front()
164	if elem == nil {
165		v.t.Errorf("call(%d): unexpected request:\n[%T] %v", v.numCalls, gotRequest, gotRequest)
166		return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected request")
167	}
168
169	rpc, _ := elem.Value.(*rpcMetadata)
170	v.rpcs.Remove(elem)
171
172	if !testutil.Equal(gotRequest, rpc.wantRequest) {
173		v.t.Errorf("call(%d): got request: [%T] %v\nwant request: [%T] %v", v.numCalls, gotRequest, gotRequest, rpc.wantRequest, rpc.wantRequest)
174	}
175	if err := rpc.wait(); err != nil {
176		return nil, err
177	}
178	return rpc.retResponse, rpc.retErr
179}
180
181// TryPop should be used only for streams. It checks whether the request in the
182// next tuple is nil, in which case the response or error should be returned to
183// the client without waiting for a request. Useful for streams where the server
184// continuously sends data (e.g. subscribe stream).
185func (v *RPCVerifier) TryPop() (bool, interface{}, error) {
186	v.mu.Lock()
187	defer v.mu.Unlock()
188
189	elem := v.rpcs.Front()
190	if elem == nil {
191		return false, nil, nil
192	}
193
194	rpc, _ := elem.Value.(*rpcMetadata)
195	if rpc.wantRequest != nil {
196		return false, nil, nil
197	}
198
199	v.rpcs.Remove(elem)
200	if err := rpc.wait(); err != nil {
201		return true, nil, err
202	}
203	return true, rpc.retResponse, rpc.retErr
204}
205
206// Flush logs an error for any remaining {request, response, error} tuples, in
207// case the client terminated early.
208func (v *RPCVerifier) Flush() {
209	v.mu.Lock()
210	defer v.mu.Unlock()
211
212	for elem := v.rpcs.Front(); elem != nil; elem = elem.Next() {
213		v.numCalls++
214		rpc, _ := elem.Value.(*rpcMetadata)
215		if rpc.wantRequest != nil {
216			v.t.Errorf("call(%d): did not receive expected request:\n[%T] %v", v.numCalls, rpc.wantRequest, rpc.wantRequest)
217		} else {
218			v.t.Errorf("call(%d): unsent response:\n[%T] %v, err = (%v)", v.numCalls, rpc.retResponse, rpc.retResponse, rpc.retErr)
219		}
220	}
221	v.rpcs.Init()
222}
223
// streamVerifiers is a FIFO of RPCVerifiers, one per expected stream
// connection.
type streamVerifiers struct {
	t *testing.T
	// Pending verifiers, oldest first; Value = *RPCVerifier.
	verifiers *list.List
	// Index of the most recent connection, used in error messages.
	numStreams int
}
230
231func newStreamVerifiers(t *testing.T) *streamVerifiers {
232	return &streamVerifiers{
233		t:          t,
234		verifiers:  list.New(),
235		numStreams: -1,
236	}
237}
238
239func (sv *streamVerifiers) Push(v *RPCVerifier) {
240	sv.verifiers.PushBack(v)
241}
242
243func (sv *streamVerifiers) Pop() (*RPCVerifier, error) {
244	sv.numStreams++
245	elem := sv.verifiers.Front()
246	if elem == nil {
247		sv.t.Errorf("stream(%d): unexpected connection with no verifiers", sv.numStreams)
248		return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected stream connection")
249	}
250
251	v, _ := elem.Value.(*RPCVerifier)
252	sv.verifiers.Remove(elem)
253	return v, nil
254}
255
256func (sv *streamVerifiers) Flush() {
257	for elem := sv.verifiers.Front(); elem != nil; elem = elem.Next() {
258		v, _ := elem.Value.(*RPCVerifier)
259		v.Flush()
260	}
261}
262
263// keyedStreamVerifiers stores indexed streamVerifiers. Examples of keys:
264// "streamType:topic_path:partition".
265type keyedStreamVerifiers struct {
266	verifiers map[string]*streamVerifiers
267}
268
269func newKeyedStreamVerifiers() *keyedStreamVerifiers {
270	return &keyedStreamVerifiers{verifiers: make(map[string]*streamVerifiers)}
271}
272
273func (kv *keyedStreamVerifiers) Push(key string, v *RPCVerifier) {
274	sv, ok := kv.verifiers[key]
275	if !ok {
276		sv = newStreamVerifiers(v.t)
277		kv.verifiers[key] = sv
278	}
279	sv.Push(v)
280}
281
282func (kv *keyedStreamVerifiers) Pop(key string) (*RPCVerifier, error) {
283	sv, ok := kv.verifiers[key]
284	if !ok {
285		return nil, status.Error(codes.FailedPrecondition, "mockserver: unexpected connection with no configured responses")
286	}
287	return sv.Pop()
288}
289
290func (kv *keyedStreamVerifiers) Flush() {
291	for _, sv := range kv.verifiers {
292		sv.Flush()
293	}
294}
295
296// Verifiers contains RPCVerifiers for unary RPCs and streaming RPCs.
297type Verifiers struct {
298	t  *testing.T
299	mu sync.Mutex
300
301	// Global list of verifiers for all unary RPCs.
302	GlobalVerifier *RPCVerifier
303	// Stream verifiers by key.
304	streamVerifiers       *keyedStreamVerifiers
305	activeStreamVerifiers []*RPCVerifier
306}
307
308// NewVerifiers creates a new instance of Verifiers for a test.
309func NewVerifiers(t *testing.T) *Verifiers {
310	return &Verifiers{
311		t:               t,
312		GlobalVerifier:  NewRPCVerifier(t),
313		streamVerifiers: newKeyedStreamVerifiers(),
314	}
315}
316
// streamType names a kind of stream; it is the first component of every
// keyedStreamVerifiers key.
type streamType string

// Known stream types.
const (
	publishStreamType    = streamType("publish")
	subscribeStreamType  = streamType("subscribe")
	commitStreamType     = streamType("commit")
	assignmentStreamType = streamType("assignment")
)
326
327func keyPartition(st streamType, path string, partition int) string {
328	return fmt.Sprintf("%s:%s:%d", st, path, partition)
329}
330
331func key(st streamType, path string) string {
332	return fmt.Sprintf("%s:%s", st, path)
333}
334
335// AddPublishStream adds verifiers for a publish stream.
336func (tv *Verifiers) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) {
337	tv.mu.Lock()
338	defer tv.mu.Unlock()
339	tv.streamVerifiers.Push(keyPartition(publishStreamType, topic, partition), streamVerifier)
340}
341
342// AddSubscribeStream adds verifiers for a subscribe stream.
343func (tv *Verifiers) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) {
344	tv.mu.Lock()
345	defer tv.mu.Unlock()
346	tv.streamVerifiers.Push(keyPartition(subscribeStreamType, subscription, partition), streamVerifier)
347}
348
349// AddCommitStream adds verifiers for a commit stream.
350func (tv *Verifiers) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) {
351	tv.mu.Lock()
352	defer tv.mu.Unlock()
353	tv.streamVerifiers.Push(keyPartition(commitStreamType, subscription, partition), streamVerifier)
354}
355
356// AddAssignmentStream adds verifiers for an assignment stream.
357func (tv *Verifiers) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) {
358	tv.mu.Lock()
359	defer tv.mu.Unlock()
360	tv.streamVerifiers.Push(key(assignmentStreamType, subscription), streamVerifier)
361}
362
363func (tv *Verifiers) popStreamVerifier(key string) (*RPCVerifier, error) {
364	tv.mu.Lock()
365	defer tv.mu.Unlock()
366	v, err := tv.streamVerifiers.Pop(key)
367	if v != nil {
368		tv.activeStreamVerifiers = append(tv.activeStreamVerifiers, v)
369	}
370	return v, err
371}
372
373func (tv *Verifiers) flush() {
374	tv.mu.Lock()
375	defer tv.mu.Unlock()
376
377	tv.GlobalVerifier.Flush()
378	tv.streamVerifiers.Flush()
379	for _, v := range tv.activeStreamVerifiers {
380		v.Flush()
381	}
382}
383