1// Copyright 2020 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13 14package test 15 16import ( 17 "container/list" 18 "fmt" 19 "sync" 20 "testing" 21 "time" 22 23 "cloud.google.com/go/internal/testutil" 24 "google.golang.org/grpc/codes" 25 "google.golang.org/grpc/status" 26) 27 28const ( 29 // blockWaitTimeout is the timeout for any wait operations to ensure no 30 // deadlocks. 31 blockWaitTimeout = 30 * time.Second 32) 33 34// Barrier is used to perform two-way synchronization betwen the server and 35// client (test) to ensure tests are deterministic. 36type Barrier struct { 37 // Used to block until the server is ready to send the response. 38 serverBlock chan struct{} 39 // Used to block until the client wants the server to send the response. 40 clientBlock chan struct{} 41 err error 42} 43 44func newBarrier() *Barrier { 45 return &Barrier{ 46 serverBlock: make(chan struct{}), 47 clientBlock: make(chan struct{}), 48 } 49} 50 51// ReleaseAfter releases the barrier, after invoking f provided by the test. 52func (b *Barrier) ReleaseAfter(f func()) { 53 // Wait for the server to reach the barrier. 54 select { 55 case <-time.After(blockWaitTimeout): 56 // Note: avoid returning a retryable code to quickly terminate the test. 57 b.err = status.Errorf(codes.FailedPrecondition, "mockserver: server did not reach barrier within %v", blockWaitTimeout) 58 case <-b.serverBlock: 59 } 60 61 // Run any test-specific code. 62 if f != nil { 63 f() 64 } 65 66 // Then close the client block. 67 close(b.clientBlock) 68} 69 70// Release should be called by the test. 71func (b *Barrier) Release() { 72 b.ReleaseAfter(nil) 73} 74 75func (b *Barrier) serverWait() error { 76 if b.err != nil { 77 return b.err 78 } 79 80 // Close the server block to signal the server reaching the point where it is 81 // ready to send the response. 82 close(b.serverBlock) 83 84 // Wait for the test to release the client block. 85 select { 86 case <-time.After(blockWaitTimeout): 87 // Note: avoid returning a retryable code to quickly terminate the test. 88 return status.Errorf(codes.FailedPrecondition, "mockserver: test did not unblock response within %v", blockWaitTimeout) 89 case <-b.clientBlock: 90 return nil 91 } 92} 93 94type rpcMetadata struct { 95 wantRequest interface{} 96 retResponse interface{} 97 retErr error 98 barrier *Barrier 99} 100 101// wait until the barrier is released by the test, or a timeout occurs. 102// Returns immediately if there was no block. 103func (r *rpcMetadata) wait() error { 104 if r.barrier == nil { 105 return nil 106 } 107 return r.barrier.serverWait() 108} 109 110// RPCVerifier stores an queue of requests expected from the client, and the 111// corresponding response or error to return. 112type RPCVerifier struct { 113 t *testing.T 114 mu sync.Mutex 115 rpcs *list.List // Value = *rpcMetadata 116 numCalls int 117} 118 119// NewRPCVerifier creates a new verifier for requests received by the server. 120func NewRPCVerifier(t *testing.T) *RPCVerifier { 121 return &RPCVerifier{ 122 t: t, 123 rpcs: list.New(), 124 numCalls: -1, 125 } 126} 127 128// Push appends a new {request, response, error} tuple. 129// 130// Valid combinations for unary and streaming RPCs: 131// - {request, response, nil} 132// - {request, nil, error} 133// 134// Additional combinations for streams only: 135// - {nil, response, nil}: send a response without a request (e.g. messages). 136// - {nil, nil, error}: break the stream without a request. 137// - {request, nil, nil}: expect a request, but don't send any response. 138func (v *RPCVerifier) Push(wantRequest interface{}, retResponse interface{}, retErr error) { 139 v.mu.Lock() 140 defer v.mu.Unlock() 141 142 v.rpcs.PushBack(&rpcMetadata{ 143 wantRequest: wantRequest, 144 retResponse: retResponse, 145 retErr: retErr, 146 }) 147} 148 149// PushWithBarrier is like Push, but returns a barrier that the test should call 150// Release when it would like the response to be sent to the client. This is 151// useful for synchronizing with work that needs to be done on the client. 152func (v *RPCVerifier) PushWithBarrier(wantRequest interface{}, retResponse interface{}, retErr error) *Barrier { 153 v.mu.Lock() 154 defer v.mu.Unlock() 155 156 barrier := newBarrier() 157 v.rpcs.PushBack(&rpcMetadata{ 158 wantRequest: wantRequest, 159 retResponse: retResponse, 160 retErr: retErr, 161 barrier: barrier, 162 }) 163 return barrier 164} 165 166// Pop validates the received request with the next {request, response, error} 167// tuple. 168func (v *RPCVerifier) Pop(gotRequest interface{}) (interface{}, error) { 169 v.mu.Lock() 170 defer v.mu.Unlock() 171 172 v.numCalls++ 173 elem := v.rpcs.Front() 174 if elem == nil { 175 v.t.Errorf("call(%d): unexpected request:\n[%T] %v", v.numCalls, gotRequest, gotRequest) 176 return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected request") 177 } 178 179 rpc, _ := elem.Value.(*rpcMetadata) 180 v.rpcs.Remove(elem) 181 182 if !testutil.Equal(gotRequest, rpc.wantRequest) { 183 v.t.Errorf("call(%d): got request: [%T] %v\nwant request: [%T] %v", v.numCalls, gotRequest, gotRequest, rpc.wantRequest, rpc.wantRequest) 184 } 185 if err := rpc.wait(); err != nil { 186 return nil, err 187 } 188 return rpc.retResponse, rpc.retErr 189} 190 191// TryPop should be used only for streams. It checks whether the request in the 192// next tuple is nil, in which case the response or error should be returned to 193// the client without waiting for a request. Useful for streams where the server 194// continuously sends data (e.g. subscribe stream). 195func (v *RPCVerifier) TryPop() (bool, interface{}, error) { 196 v.mu.Lock() 197 defer v.mu.Unlock() 198 199 elem := v.rpcs.Front() 200 if elem == nil { 201 return false, nil, nil 202 } 203 204 rpc, _ := elem.Value.(*rpcMetadata) 205 if rpc.wantRequest != nil { 206 return false, nil, nil 207 } 208 209 v.rpcs.Remove(elem) 210 if err := rpc.wait(); err != nil { 211 return true, nil, err 212 } 213 return true, rpc.retResponse, rpc.retErr 214} 215 216// Flush logs an error for any remaining {request, response, error} tuples, in 217// case the client terminated early. 218func (v *RPCVerifier) Flush() { 219 v.mu.Lock() 220 defer v.mu.Unlock() 221 222 for elem := v.rpcs.Front(); elem != nil; elem = elem.Next() { 223 v.numCalls++ 224 rpc, _ := elem.Value.(*rpcMetadata) 225 if rpc.wantRequest != nil { 226 v.t.Errorf("call(%d): did not receive expected request:\n[%T] %v", v.numCalls, rpc.wantRequest, rpc.wantRequest) 227 } else { 228 v.t.Errorf("call(%d): unsent response:\n[%T] %v, err = (%v)", v.numCalls, rpc.retResponse, rpc.retResponse, rpc.retErr) 229 } 230 } 231 v.rpcs.Init() 232} 233 234// streamVerifiers stores a queue of verifiers for unique stream connections. 235type streamVerifiers struct { 236 t *testing.T 237 verifiers *list.List // Value = *RPCVerifier 238 numStreams int 239} 240 241func newStreamVerifiers(t *testing.T) *streamVerifiers { 242 return &streamVerifiers{ 243 t: t, 244 verifiers: list.New(), 245 numStreams: -1, 246 } 247} 248 249func (sv *streamVerifiers) Push(v *RPCVerifier) { 250 sv.verifiers.PushBack(v) 251} 252 253func (sv *streamVerifiers) Pop() (*RPCVerifier, error) { 254 sv.numStreams++ 255 elem := sv.verifiers.Front() 256 if elem == nil { 257 sv.t.Errorf("stream(%d): unexpected connection with no verifiers", sv.numStreams) 258 return nil, status.Error(codes.FailedPrecondition, "mockserver: got unexpected stream connection") 259 } 260 261 v, _ := elem.Value.(*RPCVerifier) 262 sv.verifiers.Remove(elem) 263 return v, nil 264} 265 266func (sv *streamVerifiers) Flush() { 267 for elem := sv.verifiers.Front(); elem != nil; elem = elem.Next() { 268 v, _ := elem.Value.(*RPCVerifier) 269 v.Flush() 270 } 271} 272 273// keyedStreamVerifiers stores indexed streamVerifiers. Examples of keys: 274// "streamType:topic_path:partition". 275type keyedStreamVerifiers struct { 276 verifiers map[string]*streamVerifiers 277} 278 279func newKeyedStreamVerifiers() *keyedStreamVerifiers { 280 return &keyedStreamVerifiers{verifiers: make(map[string]*streamVerifiers)} 281} 282 283func (kv *keyedStreamVerifiers) Push(key string, v *RPCVerifier) { 284 sv, ok := kv.verifiers[key] 285 if !ok { 286 sv = newStreamVerifiers(v.t) 287 kv.verifiers[key] = sv 288 } 289 sv.Push(v) 290} 291 292func (kv *keyedStreamVerifiers) Pop(key string) (*RPCVerifier, error) { 293 sv, ok := kv.verifiers[key] 294 if !ok { 295 return nil, status.Error(codes.FailedPrecondition, "mockserver: unexpected connection with no configured responses") 296 } 297 return sv.Pop() 298} 299 300func (kv *keyedStreamVerifiers) Flush() { 301 for _, sv := range kv.verifiers { 302 sv.Flush() 303 } 304} 305 306// Verifiers contains RPCVerifiers for unary RPCs and streaming RPCs. 307type Verifiers struct { 308 t *testing.T 309 mu sync.Mutex 310 311 // Global list of verifiers for all unary RPCs. 312 GlobalVerifier *RPCVerifier 313 // Stream verifiers by key. 314 streamVerifiers *keyedStreamVerifiers 315 activeStreamVerifiers []*RPCVerifier 316} 317 318// NewVerifiers creates a new instance of Verifiers for a test. 319func NewVerifiers(t *testing.T) *Verifiers { 320 return &Verifiers{ 321 t: t, 322 GlobalVerifier: NewRPCVerifier(t), 323 streamVerifiers: newKeyedStreamVerifiers(), 324 } 325} 326 327// streamType is used as a key prefix for keyedStreamVerifiers. 328type streamType string 329 330const ( 331 publishStreamType streamType = "publish" 332 subscribeStreamType streamType = "subscribe" 333 commitStreamType streamType = "commit" 334 assignmentStreamType streamType = "assignment" 335) 336 337func keyPartition(st streamType, path string, partition int) string { 338 return fmt.Sprintf("%s:%s:%d", st, path, partition) 339} 340 341func key(st streamType, path string) string { 342 return fmt.Sprintf("%s:%s", st, path) 343} 344 345// AddPublishStream adds verifiers for a publish stream. 346func (tv *Verifiers) AddPublishStream(topic string, partition int, streamVerifier *RPCVerifier) { 347 tv.mu.Lock() 348 defer tv.mu.Unlock() 349 tv.streamVerifiers.Push(keyPartition(publishStreamType, topic, partition), streamVerifier) 350} 351 352// AddSubscribeStream adds verifiers for a subscribe stream. 353func (tv *Verifiers) AddSubscribeStream(subscription string, partition int, streamVerifier *RPCVerifier) { 354 tv.mu.Lock() 355 defer tv.mu.Unlock() 356 tv.streamVerifiers.Push(keyPartition(subscribeStreamType, subscription, partition), streamVerifier) 357} 358 359// AddCommitStream adds verifiers for a commit stream. 360func (tv *Verifiers) AddCommitStream(subscription string, partition int, streamVerifier *RPCVerifier) { 361 tv.mu.Lock() 362 defer tv.mu.Unlock() 363 tv.streamVerifiers.Push(keyPartition(commitStreamType, subscription, partition), streamVerifier) 364} 365 366// AddAssignmentStream adds verifiers for an assignment stream. 367func (tv *Verifiers) AddAssignmentStream(subscription string, streamVerifier *RPCVerifier) { 368 tv.mu.Lock() 369 defer tv.mu.Unlock() 370 tv.streamVerifiers.Push(key(assignmentStreamType, subscription), streamVerifier) 371} 372 373func (tv *Verifiers) popStreamVerifier(key string) (*RPCVerifier, error) { 374 tv.mu.Lock() 375 defer tv.mu.Unlock() 376 v, err := tv.streamVerifiers.Pop(key) 377 if v != nil { 378 tv.activeStreamVerifiers = append(tv.activeStreamVerifiers, v) 379 } 380 return v, err 381} 382 383func (tv *Verifiers) flush() { 384 tv.mu.Lock() 385 defer tv.mu.Unlock() 386 387 tv.GlobalVerifier.Flush() 388 tv.streamVerifiers.Flush() 389 for _, v := range tv.activeStreamVerifiers { 390 v.Flush() 391 } 392} 393