1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package test
15
16import (
17	"container/list"
18	"fmt"
19	"sync"
20	"testing"
21	"time"
22
23	"cloud.google.com/go/internal/testutil"
24	"google.golang.org/grpc/codes"
25	"google.golang.org/grpc/status"
26)
27
const (
	// blockWaitTimeout is the upper bound on every barrier wait in this file
	// (the test waiting for the server, and the server waiting for the test),
	// so that a missed handshake surfaces as an error instead of a deadlock.
	blockWaitTimeout = 30 * time.Second
)
33
34// Barrier is used to perform two-way synchronization betwen the server and
35// client (test) to ensure tests are deterministic.
36type Barrier struct {
37	// Used to block until the server is ready to send the response.
38	serverBlock chan struct{}
39	// Used to block until the client wants the server to send the response.
40	clientBlock chan struct{}
41	err         error
42}
43
44func newBarrier() *Barrier {
45	return &Barrier{
46		serverBlock: make(chan struct{}),
47		clientBlock: make(chan struct{}),
48	}
49}
50
51// ReleaseAfter releases the barrier, after invoking f provided by the test.
52func (b *Barrier) ReleaseAfter(f func()) {
53	// Wait for the server to reach the barrier.
54	select {
55	case <-time.After(blockWaitTimeout):
56		// Note: avoid returning a retryable code to quickly terminate the test.
57		b.err = status.Errorf(codes.FailedPrecondition, "mockserver: server did not reach barrier within %v", blockWaitTimeout)
58	case <-b.serverBlock:
59	}
60
61	// Run any test-specific code.
62	if f != nil {
63		f()
64	}
65
66	// Then close the client block.
67	close(b.clientBlock)
68}
69
70// Release should be called by the test.
71func (b *Barrier) Release() {
72	b.ReleaseAfter(nil)
73}
74
75func (b *Barrier) serverWait() error {
76	if b.err != nil {
77		return b.err
78	}
79
80	// Close the server block to signal the server reaching the point where it is
81	// ready to send the response.
82	close(b.serverBlock)
83
84	// Wait for the test to release the client block.
85	select {
86	case <-time.After(blockWaitTimeout):
87		// Note: avoid returning a retryable code to quickly terminate the test.
88		return status.Errorf(codes.FailedPrecondition, "mockserver: test did not unblock response within %v", blockWaitTimeout)
89	case <-b.clientBlock:
90		return nil
91	}
92}
93
// rpcMetadata is one entry in an RPCVerifier queue: the request expected from
// the client (wantRequest), the response (retResponse) or error (retErr) to
// return for it, and an optional barrier that delays the reply until the test
// releases it. A nil barrier means the reply is not delayed.
type rpcMetadata struct {
	wantRequest interface{}
	retResponse interface{}
	retErr      error
	barrier     *Barrier
}
100
101// wait until the barrier is released by the test, or a timeout occurs.
102// Returns immediately if there was no block.
103func (r *rpcMetadata) wait() error {
104	if r.barrier == nil {
105		return nil
106	}
107	return r.barrier.serverWait()
108}
109
// RPCVerifier stores a queue of requests expected from the client, and the
// corresponding response or error to return.
//
// mu is locked by every method that touches rpcs or numCalls. numCalls starts
// at -1 and is advanced once per Pop call and once per leftover entry reported
// by Flush; it is the index printed as "call(N)" in failure messages. TryPop
// does not advance it.
type RPCVerifier struct {
	t        *testing.T
	mu       sync.Mutex
	rpcs     *list.List // Value = *rpcMetadata
	numCalls int
}
118
119// NewRPCVerifier creates a new verifier for requests received by the server.
120func NewRPCVerifier(t *testing.T) *RPCVerifier {
121	return &RPCVerifier{
122		t:        t,
123		rpcs:     list.New(),
124		numCalls: -1,
125	}
126}
127
128// Push appends a new {request, response, error} tuple.
129//
130// Valid combinations for unary and streaming RPCs:
131// - {request, response, nil}
132// - {request, nil, error}
133//
134// Additional combinations for streams only:
135// - {nil, response, nil}: send a response without a request (e.g. messages).
136// - {nil, nil, error}: break the stream without a request.
137// - {request, nil, nil}: expect a request, but don't send any response.
138func (v *RPCVerifier) Push(wantRequest interface{}, retResponse interface{}, retErr error) {
139	v.mu.Lock()
140	defer v.mu.Unlock()
141
142	v.rpcs.PushBack(&rpcMetadata{
143		wantRequest: wantRequest,
144		retResponse: retResponse,
145		retErr:      retErr,
146	})
147}
148
149// PushWithBarrier is like Push, but returns a barrier that the test should call
150// Release when it would like the response to be sent to the client. This is
151// useful for synchronizing with work that needs to be done on the client.
152func (v *RPCVerifier) PushWithBarrier(wantRequest interface{}, retResponse interface{}, retErr error) *Barrier {
153	v.mu.Lock()
154	defer v.mu.Unlock()
155
156	barrier := newBarrier()
157	v.rpcs.PushBack(&rpcMetadata{
158		wantRequest: wantRequest,
159		retResponse: retResponse,
160		retErr:      retErr,
161		barrier:     barrier,
162	})
163	return barrier
164}
165
166// Pop validates the received request with the next {request, response, error}
167// tuple.
168func (v *RPCVerifier) Pop(gotRequest interface{}) (interface{}, error) {
169	v.mu.Lock()
170	defer v.mu.Unlock()
171
172	v.numCalls++
173	elem := v.rpcs.Front()
174	if elem == nil {
175		v.t.Errorf("call(%d): unexpected request:\n[%T] %v", v.numCalls, gotRequest, gotRequest)
176		return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected request")
177	}
178
179	rpc, _ := elem.Value.(*rpcMetadata)
180	v.rpcs.Remove(elem)
181
182	if !testutil.Equal(gotRequest, rpc.wantRequest) {
183		v.t.Errorf("call(%d): got request: [%T] %v\nwant request: [%T] %v", v.numCalls, gotRequest, gotRequest, rpc.wantRequest, rpc.wantRequest)
184	}
185	if err := rpc.wait(); err != nil {
186		return nil, err
187	}
188	return rpc.retResponse, rpc.retErr
189}
190
191// TryPop should be used only for streams. It checks whether the request in the
192// next tuple is nil, in which case the response or error should be returned to
193// the client without waiting for a request. Useful for streams where the server
194// continuously sends data (e.g. subscribe stream).
195func (v *RPCVerifier) TryPop() (bool, interface{}, error) {
196	v.mu.Lock()
197	defer v.mu.Unlock()
198
199	elem := v.rpcs.Front()
200	if elem == nil {
201		return false, nil, nil
202	}
203
204	rpc, _ := elem.Value.(*rpcMetadata)
205	if rpc.wantRequest != nil {
206		return false, nil, nil
207	}
208
209	v.rpcs.Remove(elem)
210	if err := rpc.wait(); err != nil {
211		return true, nil, err
212	}
213	return true, rpc.retResponse, rpc.retErr
214}
215
216// Flush logs an error for any remaining {request, response, error} tuples, in
217// case the client terminated early.
218func (v *RPCVerifier) Flush() {
219	v.mu.Lock()
220	defer v.mu.Unlock()
221
222	for elem := v.rpcs.Front(); elem != nil; elem = elem.Next() {
223		v.numCalls++
224		rpc, _ := elem.Value.(*rpcMetadata)
225		if rpc.wantRequest != nil {
226			v.t.Errorf("call(%d): did not receive expected request:\n[%T] %v", v.numCalls, rpc.wantRequest, rpc.wantRequest)
227		} else {
228			v.t.Errorf("call(%d): unsent response:\n[%T] %v, err = (%v)", v.numCalls, rpc.retResponse, rpc.retResponse, rpc.retErr)
229		}
230	}
231	v.rpcs.Init()
232}
233
// streamVerifiers stores a queue of verifiers for unique stream connections.
// Each Pop hands out the verifier at the front of the queue. numStreams starts
// at -1 and is advanced by every Pop; it is the index printed as "stream(N)"
// when a connection arrives with no verifier left. The type has no lock of its
// own.
type streamVerifiers struct {
	t          *testing.T
	verifiers  *list.List // Value = *RPCVerifier
	numStreams int
}
240
241func newStreamVerifiers(t *testing.T) *streamVerifiers {
242	return &streamVerifiers{
243		t:          t,
244		verifiers:  list.New(),
245		numStreams: -1,
246	}
247}
248
249func (sv *streamVerifiers) Push(v *RPCVerifier) {
250	sv.verifiers.PushBack(v)
251}
252
253func (sv *streamVerifiers) Pop() (*RPCVerifier, error) {
254	sv.numStreams++
255	elem := sv.verifiers.Front()
256	if elem == nil {
257		sv.t.Errorf("stream(%d): unexpected connection with no verifiers", sv.numStreams)
258		return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected stream connection")
259	}
260
261	v, _ := elem.Value.(*RPCVerifier)
262	sv.verifiers.Remove(elem)
263	return v, nil
264}
265
266func (sv *streamVerifiers) Flush() {
267	for elem := sv.verifiers.Front(); elem != nil; elem = elem.Next() {
268		v, _ := elem.Value.(*RPCVerifier)
269		v.Flush()
270	}
271}
272
// keyedStreamVerifiers stores indexed streamVerifiers. Examples of keys:
// "streamType:topic_path:partition".
//
// Each key maps to its own queue of verifiers, created on the first Push for
// that key. The type has no lock of its own.
type keyedStreamVerifiers struct {
	verifiers map[string]*streamVerifiers
}
278
279func newKeyedStreamVerifiers() *keyedStreamVerifiers {
280	return &keyedStreamVerifiers{verifiers: make(map[string]*streamVerifiers)}
281}
282
283func (kv *keyedStreamVerifiers) Push(key string, v *RPCVerifier) {
284	sv, ok := kv.verifiers[key]
285	if !ok {
286		sv = newStreamVerifiers(v.t)
287		kv.verifiers[key] = sv
288	}
289	sv.Push(v)
290}
291
292func (kv *keyedStreamVerifiers) Pop(key string) (*RPCVerifier, error) {
293	sv, ok := kv.verifiers[key]
294	if !ok {
295		return nil, status.Error(codes.FailedPrecondition, "mockserver: unexpected connection with no configured responses")
296	}
297	return sv.Pop()
298}
299
300func (kv *keyedStreamVerifiers) Flush() {
301	for _, sv := range kv.verifiers {
302		sv.Flush()
303	}
304}
305
// Verifiers contains RPCVerifiers for unary RPCs and streaming RPCs.
//
// mu is held by the Add*Stream methods, popStreamVerifier and flush, i.e. by
// every method in this file that touches streamVerifiers or
// activeStreamVerifiers. activeStreamVerifiers collects the verifiers already
// handed out by popStreamVerifier so that flush can still flush them.
type Verifiers struct {
	t  *testing.T
	mu sync.Mutex

	// Global list of verifiers for all unary RPCs.
	GlobalVerifier *RPCVerifier
	// Stream verifiers by key.
	streamVerifiers       *keyedStreamVerifiers
	activeStreamVerifiers []*RPCVerifier
}
317
318// NewVerifiers creates a new instance of Verifiers for a test.
319func NewVerifiers(t *testing.T) *Verifiers {
320	return &Verifiers{
321		t:               t,
322		GlobalVerifier:  NewRPCVerifier(t),
323		streamVerifiers: newKeyedStreamVerifiers(),
324	}
325}
326
// streamType is used as a key prefix for keyedStreamVerifiers. It is the
// first component of the keys built by keyPartition and key.
type streamType string

// One stream type per Add*Stream method on Verifiers.
const (
	publishStreamType    streamType = "publish"
	subscribeStreamType  streamType = "subscribe"
	commitStreamType     streamType = "commit"
	assignmentStreamType streamType = "assignment"
)
336
337func keyPartition(st streamType, path string, partition int) string {
338	return fmt.Sprintf("%s:%s:%d", st, path, partition)
339}
340
341func key(st streamType, path string) string {
342	return fmt.Sprintf("%s:%s", st, path)
343}
344
345// AddPublishStream adds verifiers for a publish stream.
346func (tv *Verifiers) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) {
347	tv.mu.Lock()
348	defer tv.mu.Unlock()
349	tv.streamVerifiers.Push(keyPartition(publishStreamType, topic, partition), streamVerifier)
350}
351
352// AddSubscribeStream adds verifiers for a subscribe stream.
353func (tv *Verifiers) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) {
354	tv.mu.Lock()
355	defer tv.mu.Unlock()
356	tv.streamVerifiers.Push(keyPartition(subscribeStreamType, subscription, partition), streamVerifier)
357}
358
359// AddCommitStream adds verifiers for a commit stream.
360func (tv *Verifiers) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) {
361	tv.mu.Lock()
362	defer tv.mu.Unlock()
363	tv.streamVerifiers.Push(keyPartition(commitStreamType, subscription, partition), streamVerifier)
364}
365
366// AddAssignmentStream adds verifiers for an assignment stream.
367func (tv *Verifiers) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) {
368	tv.mu.Lock()
369	defer tv.mu.Unlock()
370	tv.streamVerifiers.Push(key(assignmentStreamType, subscription), streamVerifier)
371}
372
373func (tv *Verifiers) popStreamVerifier(key string) (*RPCVerifier, error) {
374	tv.mu.Lock()
375	defer tv.mu.Unlock()
376	v, err := tv.streamVerifiers.Pop(key)
377	if v != nil {
378		tv.activeStreamVerifiers = append(tv.activeStreamVerifiers, v)
379	}
380	return v, err
381}
382
383func (tv *Verifiers) flush() {
384	tv.mu.Lock()
385	defer tv.mu.Unlock()
386
387	tv.GlobalVerifier.Flush()
388	tv.streamVerifiers.Flush()
389	for _, v := range tv.activeStreamVerifiers {
390		v.Flush()
391	}
392}
393