1// Copyright 2020 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13 14package test 15 16import ( 17 "container/list" 18 "fmt" 19 "sync" 20 "testing" 21 "time" 22 23 "cloud.google.com/go/internal/testutil" 24 "google.golang.org/grpc/codes" 25 "google.golang.org/grpc/status" 26) 27 28const ( 29 // blockWaitTimeout is the timeout for any wait operations to ensure no 30 // deadlocks. 31 blockWaitTimeout = 30 * time.Second 32) 33 34// Barrier is used to perform two-way synchronization betwen the server and 35// client (test) to ensure tests are deterministic. 36type Barrier struct { 37 // Used to block until the server is ready to send the response. 38 serverBlock chan struct{} 39 // Used to block until the client wants the server to send the response. 40 clientBlock chan struct{} 41 err error 42} 43 44func newBarrier() *Barrier { 45 return &Barrier{ 46 serverBlock: make(chan struct{}), 47 clientBlock: make(chan struct{}), 48 } 49} 50 51// Release should be called by the test. 52func (b *Barrier) Release() { 53 // Wait for the server to reach the barrier. 54 select { 55 case <-time.After(blockWaitTimeout): 56 // Note: avoid returning a retryable code to quickly terminate the test. 57 b.err = status.Errorf(codes.FailedPrecondition, "mockserver: server did not reach barrier within %v", blockWaitTimeout) 58 case <-b.serverBlock: 59 } 60 61 // Then close the client block. 62 close(b.clientBlock) 63} 64 65func (b *Barrier) serverWait() error { 66 if b.err != nil { 67 return b.err 68 } 69 70 // Close the server block to signal the server reaching the point where it is 71 // ready to send the response. 72 close(b.serverBlock) 73 74 // Wait for the test to release the client block. 75 select { 76 case <-time.After(blockWaitTimeout): 77 // Note: avoid returning a retryable code to quickly terminate the test. 78 return status.Errorf(codes.FailedPrecondition, "mockserver: test did not unblock response within %v", blockWaitTimeout) 79 case <-b.clientBlock: 80 return nil 81 } 82} 83 84type rpcMetadata struct { 85 wantRequest interface{} 86 retResponse interface{} 87 retErr error 88 barrier *Barrier 89} 90 91// wait until the barrier is released by the test, or a timeout occurs. 92// Returns immediately if there was no block. 93func (r *rpcMetadata) wait() error { 94 if r.barrier == nil { 95 return nil 96 } 97 return r.barrier.serverWait() 98} 99 100// RPCVerifier stores an queue of requests expected from the client, and the 101// corresponding response or error to return. 102type RPCVerifier struct { 103 t *testing.T 104 mu sync.Mutex 105 rpcs *list.List // Value = *rpcMetadata 106 numCalls int 107} 108 109// NewRPCVerifier creates a new verifier for requests received by the server. 110func NewRPCVerifier(t *testing.T) *RPCVerifier { 111 return &RPCVerifier{ 112 t: t, 113 rpcs: list.New(), 114 numCalls: -1, 115 } 116} 117 118// Push appends a new {request, response, error} tuple. 119// 120// Valid combinations for unary and streaming RPCs: 121// - {request, response, nil} 122// - {request, nil, error} 123// 124// Additional combinations for streams only: 125// - {nil, response, nil}: send a response without a request (e.g. messages). 126// - {nil, nil, error}: break the stream without a request. 127// - {request, nil, nil}: expect a request, but don't send any response. 128func (v *RPCVerifier) Push(wantRequest interface{}, retResponse interface{}, retErr error) { 129 v.mu.Lock() 130 defer v.mu.Unlock() 131 132 v.rpcs.PushBack(&rpcMetadata{ 133 wantRequest: wantRequest, 134 retResponse: retResponse, 135 retErr: retErr, 136 }) 137} 138 139// PushWithBarrier is like Push, but returns a barrier that the test should call 140// Release when it would like the response to be sent to the client. This is 141// useful for synchronizing with work that needs to be done on the client. 142func (v *RPCVerifier) PushWithBarrier(wantRequest interface{}, retResponse interface{}, retErr error) *Barrier { 143 v.mu.Lock() 144 defer v.mu.Unlock() 145 146 barrier := newBarrier() 147 v.rpcs.PushBack(&rpcMetadata{ 148 wantRequest: wantRequest, 149 retResponse: retResponse, 150 retErr: retErr, 151 barrier: barrier, 152 }) 153 return barrier 154} 155 156// Pop validates the received request with the next {request, response, error} 157// tuple. 158func (v *RPCVerifier) Pop(gotRequest interface{}) (interface{}, error) { 159 v.mu.Lock() 160 defer v.mu.Unlock() 161 162 v.numCalls++ 163 elem := v.rpcs.Front() 164 if elem == nil { 165 v.t.Errorf("call(%d): unexpected request:\n[%T] %v", v.numCalls, gotRequest, gotRequest) 166 return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected request") 167 } 168 169 rpc, _ := elem.Value.(*rpcMetadata) 170 v.rpcs.Remove(elem) 171 172 if !testutil.Equal(gotRequest, rpc.wantRequest) { 173 v.t.Errorf("call(%d): got request: [%T] %v\nwant request: [%T] %v", v.numCalls, gotRequest, gotRequest, rpc.wantRequest, rpc.wantRequest) 174 } 175 if err := rpc.wait(); err != nil { 176 return nil, err 177 } 178 return rpc.retResponse, rpc.retErr 179} 180 181// TryPop should be used only for streams. It checks whether the request in the 182// next tuple is nil, in which case the response or error should be returned to 183// the client without waiting for a request. Useful for streams where the server 184// continuously sends data (e.g. subscribe stream). 185func (v *RPCVerifier) TryPop() (bool, interface{}, error) { 186 v.mu.Lock() 187 defer v.mu.Unlock() 188 189 elem := v.rpcs.Front() 190 if elem == nil { 191 return false, nil, nil 192 } 193 194 rpc, _ := elem.Value.(*rpcMetadata) 195 if rpc.wantRequest != nil { 196 return false, nil, nil 197 } 198 199 v.rpcs.Remove(elem) 200 if err := rpc.wait(); err != nil { 201 return true, nil, err 202 } 203 return true, rpc.retResponse, rpc.retErr 204} 205 206// Flush logs an error for any remaining {request, response, error} tuples, in 207// case the client terminated early. 208func (v *RPCVerifier) Flush() { 209 v.mu.Lock() 210 defer v.mu.Unlock() 211 212 for elem := v.rpcs.Front(); elem != nil; elem = elem.Next() { 213 v.numCalls++ 214 rpc, _ := elem.Value.(*rpcMetadata) 215 if rpc.wantRequest != nil { 216 v.t.Errorf("call(%d): did not receive expected request:\n[%T] %v", v.numCalls, rpc.wantRequest, rpc.wantRequest) 217 } else { 218 v.t.Errorf("call(%d): unsent response:\n[%T] %v, err = (%v)", v.numCalls, rpc.retResponse, rpc.retResponse, rpc.retErr) 219 } 220 } 221 v.rpcs.Init() 222} 223 224// streamVerifiers stores a queue of verifiers for unique stream connections. 225type streamVerifiers struct { 226 t *testing.T 227 verifiers *list.List // Value = *RPCVerifier 228 numStreams int 229} 230 231func newStreamVerifiers(t *testing.T) *streamVerifiers { 232 return &streamVerifiers{ 233 t: t, 234 verifiers: list.New(), 235 numStreams: -1, 236 } 237} 238 239func (sv *streamVerifiers) Push(v *RPCVerifier) { 240 sv.verifiers.PushBack(v) 241} 242 243func (sv *streamVerifiers) Pop() (*RPCVerifier, error) { 244 sv.numStreams++ 245 elem := sv.verifiers.Front() 246 if elem == nil { 247 sv.t.Errorf("stream(%d): unexpected connection with no verifiers", sv.numStreams) 248 return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected stream connection") 249 } 250 251 v, _ := elem.Value.(*RPCVerifier) 252 sv.verifiers.Remove(elem) 253 return v, nil 254} 255 256func (sv *streamVerifiers) Flush() { 257 for elem := sv.verifiers.Front(); elem != nil; elem = elem.Next() { 258 v, _ := elem.Value.(*RPCVerifier) 259 v.Flush() 260 } 261} 262 263// keyedStreamVerifiers stores indexed streamVerifiers. Examples of keys: 264// "streamType:topic_path:partition". 265type keyedStreamVerifiers struct { 266 verifiers map[string]*streamVerifiers 267} 268 269func newKeyedStreamVerifiers() *keyedStreamVerifiers { 270 return &keyedStreamVerifiers{verifiers: make(map[string]*streamVerifiers)} 271} 272 273func (kv *keyedStreamVerifiers) Push(key string, v *RPCVerifier) { 274 sv, ok := kv.verifiers[key] 275 if !ok { 276 sv = newStreamVerifiers(v.t) 277 kv.verifiers[key] = sv 278 } 279 sv.Push(v) 280} 281 282func (kv *keyedStreamVerifiers) Pop(key string) (*RPCVerifier, error) { 283 sv, ok := kv.verifiers[key] 284 if !ok { 285 return nil, status.Error(codes.FailedPrecondition, "mockserver: unexpected connection with no configured responses") 286 } 287 return sv.Pop() 288} 289 290func (kv *keyedStreamVerifiers) Flush() { 291 for _, sv := range kv.verifiers { 292 sv.Flush() 293 } 294} 295 296// Verifiers contains RPCVerifiers for unary RPCs and streaming RPCs. 297type Verifiers struct { 298 t *testing.T 299 mu sync.Mutex 300 301 // Global list of verifiers for all unary RPCs. 302 GlobalVerifier *RPCVerifier 303 // Stream verifiers by key. 304 streamVerifiers *keyedStreamVerifiers 305 activeStreamVerifiers []*RPCVerifier 306} 307 308// NewVerifiers creates a new instance of Verifiers for a test. 309func NewVerifiers(t *testing.T) *Verifiers { 310 return &Verifiers{ 311 t: t, 312 GlobalVerifier: NewRPCVerifier(t), 313 streamVerifiers: newKeyedStreamVerifiers(), 314 } 315} 316 317// streamType is used as a key prefix for keyedStreamVerifiers. 318type streamType string 319 320const ( 321 publishStreamType streamType = "publish" 322 subscribeStreamType streamType = "subscribe" 323 commitStreamType streamType = "commit" 324 assignmentStreamType streamType = "assignment" 325) 326 327func keyPartition(st streamType, path string, partition int) string { 328 return fmt.Sprintf("%s:%s:%d", st, path, partition) 329} 330 331func key(st streamType, path string) string { 332 return fmt.Sprintf("%s:%s", st, path) 333} 334 335// AddPublishStream adds verifiers for a publish stream. 336func (tv *Verifiers) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) { 337 tv.mu.Lock() 338 defer tv.mu.Unlock() 339 tv.streamVerifiers.Push(keyPartition(publishStreamType, topic, partition), streamVerifier) 340} 341 342// AddSubscribeStream adds verifiers for a subscribe stream. 343func (tv *Verifiers) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) { 344 tv.mu.Lock() 345 defer tv.mu.Unlock() 346 tv.streamVerifiers.Push(keyPartition(subscribeStreamType, subscription, partition), streamVerifier) 347} 348 349// AddCommitStream adds verifiers for a commit stream. 350func (tv *Verifiers) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) { 351 tv.mu.Lock() 352 defer tv.mu.Unlock() 353 tv.streamVerifiers.Push(keyPartition(commitStreamType, subscription, partition), streamVerifier) 354} 355 356// AddAssignmentStream adds verifiers for an assignment stream. 357func (tv *Verifiers) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) { 358 tv.mu.Lock() 359 defer tv.mu.Unlock() 360 tv.streamVerifiers.Push(key(assignmentStreamType, subscription), streamVerifier) 361} 362 363func (tv *Verifiers) popStreamVerifier(key string) (*RPCVerifier, error) { 364 tv.mu.Lock() 365 defer tv.mu.Unlock() 366 v, err := tv.streamVerifiers.Pop(key) 367 if v != nil { 368 tv.activeStreamVerifiers = append(tv.activeStreamVerifiers, v) 369 } 370 return v, err 371} 372 373func (tv *Verifiers) flush() { 374 tv.mu.Lock() 375 defer tv.mu.Unlock() 376 377 tv.GlobalVerifier.Flush() 378 tv.streamVerifiers.Flush() 379 for _, v := range tv.activeStreamVerifiers { 380 v.Flush() 381 } 382} 383