1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"errors"
18	"fmt"
19	"testing"
20	"time"
21
22	"cloud.google.com/go/pubsublite/internal/test"
23)
24
// receiveStatusTimeout is the longest VerifyStatus will wait for a status
// notification to arrive before reporting a test error.
const receiveStatusTimeout time.Duration = 5 * time.Second
26
27type testStatusChangeReceiver struct {
28	// Status change notifications are fired asynchronously, so a channel receives
29	// the statuses.
30	statusC    chan serviceStatus
31	lastStatus serviceStatus
32	name       string
33}
34
35func newTestStatusChangeReceiver(name string) *testStatusChangeReceiver {
36	return &testStatusChangeReceiver{
37		statusC: make(chan serviceStatus, 1),
38		name:    name,
39	}
40}
41
42func (sr *testStatusChangeReceiver) Handle() interface{} { return sr }
43
44func (sr *testStatusChangeReceiver) OnStatusChange(handle serviceHandle, status serviceStatus, err error) {
45	sr.statusC <- status
46}
47
48func (sr *testStatusChangeReceiver) VerifyStatus(t *testing.T, want serviceStatus) {
49	select {
50	case status := <-sr.statusC:
51		if status <= sr.lastStatus {
52			t.Errorf("%s: Duplicate service status: %d, last status: %d", sr.name, status, sr.lastStatus)
53		}
54		if status != want {
55			t.Errorf("%s: Got service status: %d, want: %d", sr.name, status, want)
56		}
57		sr.lastStatus = status
58	case <-time.After(receiveStatusTimeout):
59		t.Errorf("%s: Did not receive status within %v", sr.name, receiveStatusTimeout)
60	}
61}
62
63func (sr *testStatusChangeReceiver) VerifyNoStatusChanges(t *testing.T) {
64	select {
65	case status := <-sr.statusC:
66		t.Errorf("%s: Unexpected service status: %d", sr.name, status)
67	default:
68	}
69}
70
71type testService struct {
72	receiver *testStatusChangeReceiver
73	abstractService
74}
75
76func newTestService(name string) *testService {
77	receiver := newTestStatusChangeReceiver(name)
78	ts := &testService{receiver: receiver}
79	ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange)
80	return ts
81}
82
83func (ts *testService) Start() { ts.UpdateStatus(serviceStarting, nil) }
84func (ts *testService) Stop()  { ts.UpdateStatus(serviceTerminating, nil) }
85
// UpdateStatus applies targetStatus and err to the service by calling
// unsafeUpdateStatus with ts.mu held. This lets tests drive the lifecycle
// directly.
func (ts *testService) UpdateStatus(targetStatus serviceStatus, err error) {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	ts.unsafeUpdateStatus(targetStatus, err)
}
91
92func TestServiceUpdateStatusIsLinear(t *testing.T) {
93	err1 := errors.New("error1")
94	err2 := errors.New("error2")
95
96	service := newTestService("service")
97	service.UpdateStatus(serviceStarting, nil)
98	service.receiver.VerifyStatus(t, serviceStarting)
99
100	service.UpdateStatus(serviceActive, nil)
101	service.UpdateStatus(serviceActive, nil)
102	service.receiver.VerifyStatus(t, serviceActive)
103
104	service.UpdateStatus(serviceTerminating, err1)
105	service.UpdateStatus(serviceStarting, nil)
106	service.UpdateStatus(serviceTerminating, nil)
107	service.receiver.VerifyStatus(t, serviceTerminating)
108
109	service.UpdateStatus(serviceTerminated, err2)
110	service.UpdateStatus(serviceTerminated, nil)
111	service.receiver.VerifyStatus(t, serviceTerminated)
112
113	// Verify that the first error is not clobbered by the second.
114	if got, want := service.Error(), err1; !test.ErrorEqual(got, want) {
115		t.Errorf("service.Error(): got (%v), want (%v)", got, want)
116	}
117}
118
119func TestServiceCheckServiceStatus(t *testing.T) {
120	for _, tc := range []struct {
121		status  serviceStatus
122		wantErr error
123	}{
124		{
125			status:  serviceUninitialized,
126			wantErr: ErrServiceUninitialized,
127		},
128		{
129			status:  serviceStarting,
130			wantErr: ErrServiceStarting,
131		},
132		{
133			status:  serviceActive,
134			wantErr: nil,
135		},
136		{
137			status:  serviceTerminating,
138			wantErr: ErrServiceStopped,
139		},
140		{
141			status:  serviceTerminated,
142			wantErr: ErrServiceStopped,
143		},
144	} {
145		t.Run(fmt.Sprintf("Status=%v", tc.status), func(t *testing.T) {
146			s := newTestService("service")
147			s.UpdateStatus(tc.status, nil)
148			if gotErr := s.unsafeCheckServiceStatus(); !test.ErrorEqual(gotErr, tc.wantErr) {
149				t.Errorf("service.unsafeCheckServiceStatus(): got (%v), want (%v)", gotErr, tc.wantErr)
150			}
151		})
152	}
153}
154
155func TestServiceAddRemoveStatusChangeReceiver(t *testing.T) {
156	receiver1 := newTestStatusChangeReceiver("receiver1")
157	receiver2 := newTestStatusChangeReceiver("receiver2")
158	receiver3 := newTestStatusChangeReceiver("receiver3")
159
160	service := new(testService)
161	service.AddStatusChangeReceiver(receiver1.Handle(), receiver1.OnStatusChange)
162	service.AddStatusChangeReceiver(receiver2.Handle(), receiver2.OnStatusChange)
163	service.AddStatusChangeReceiver(receiver3.Handle(), receiver3.OnStatusChange)
164
165	t.Run("All receivers", func(t *testing.T) {
166		service.UpdateStatus(serviceActive, nil)
167
168		receiver1.VerifyStatus(t, serviceActive)
169		receiver2.VerifyStatus(t, serviceActive)
170		receiver3.VerifyStatus(t, serviceActive)
171	})
172
173	t.Run("receiver1 removed", func(t *testing.T) {
174		service.RemoveStatusChangeReceiver(receiver1.Handle())
175		service.UpdateStatus(serviceTerminating, nil)
176
177		receiver1.VerifyNoStatusChanges(t)
178		receiver2.VerifyStatus(t, serviceTerminating)
179		receiver3.VerifyStatus(t, serviceTerminating)
180	})
181
182	t.Run("receiver2 removed", func(t *testing.T) {
183		service.RemoveStatusChangeReceiver(receiver2.Handle())
184		service.UpdateStatus(serviceTerminated, nil)
185
186		receiver1.VerifyNoStatusChanges(t)
187		receiver2.VerifyNoStatusChanges(t)
188		receiver3.VerifyStatus(t, serviceTerminated)
189	})
190}
191
192type testCompositeService struct {
193	receiver *testStatusChangeReceiver
194	compositeService
195}
196
197func newTestCompositeService(name string) *testCompositeService {
198	receiver := newTestStatusChangeReceiver(name)
199	ts := &testCompositeService{receiver: receiver}
200	ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange)
201	ts.init()
202	return ts
203}
204
// AddServices adds services as children of the composite service. It holds
// ts.mu while calling unsafeAddServices and returns that call's error.
func (ts *testCompositeService) AddServices(services ...service) error {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	return ts.unsafeAddServices(services...)
}
210
211func (ts *testCompositeService) RemoveService(service service) {
212	ts.mu.Lock()
213	defer ts.mu.Unlock()
214	ts.unsafeRemoveService(service)
215}
216
// DependenciesLen returns the number of entries in ts.dependencies, read
// while holding ts.mu.
func (ts *testCompositeService) DependenciesLen() int {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	return len(ts.dependencies)
}
222
// RemovedLen returns the number of entries in ts.removed, read while holding
// ts.mu.
func (ts *testCompositeService) RemovedLen() int {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	return len(ts.removed)
}
228
229func TestCompositeServiceNormalStop(t *testing.T) {
230	child1 := newTestService("child1")
231	child2 := newTestService("child2")
232	child3 := newTestService("child3")
233	parent := newTestCompositeService("parent")
234	if err := parent.AddServices(child1, child2); err != nil {
235		t.Errorf("AddServices() got err: %v", err)
236	}
237
238	t.Run("Starting", func(t *testing.T) {
239		wantState := serviceUninitialized
240		if child1.Status() != wantState {
241			t.Errorf("child1: current service status: got %d, want %d", child1.Status(), wantState)
242		}
243		if child2.Status() != wantState {
244			t.Errorf("child2: current service status: got %d, want %d", child2.Status(), wantState)
245		}
246
247		parent.Start()
248
249		child1.receiver.VerifyStatus(t, serviceStarting)
250		child2.receiver.VerifyStatus(t, serviceStarting)
251		parent.receiver.VerifyStatus(t, serviceStarting)
252
253		// child3 is added after Start() and should be automatically started.
254		if child3.Status() != wantState {
255			t.Errorf("child3: current service status: got %d, want %d", child3.Status(), wantState)
256		}
257		if err := parent.AddServices(child3); err != nil {
258			t.Errorf("AddServices() got err: %v", err)
259		}
260		child3.receiver.VerifyStatus(t, serviceStarting)
261	})
262
263	t.Run("Active", func(t *testing.T) {
264		// parent service is active once all children are active.
265		child1.UpdateStatus(serviceActive, nil)
266		child2.UpdateStatus(serviceActive, nil)
267		parent.receiver.VerifyNoStatusChanges(t)
268		child3.UpdateStatus(serviceActive, nil)
269
270		child1.receiver.VerifyStatus(t, serviceActive)
271		child2.receiver.VerifyStatus(t, serviceActive)
272		child3.receiver.VerifyStatus(t, serviceActive)
273		parent.receiver.VerifyStatus(t, serviceActive)
274		if err := parent.WaitStarted(); err != nil {
275			t.Errorf("compositeService.WaitStarted() got err: %v", err)
276		}
277	})
278
279	t.Run("Stopping", func(t *testing.T) {
280		parent.Stop()
281
282		child1.receiver.VerifyStatus(t, serviceTerminating)
283		child2.receiver.VerifyStatus(t, serviceTerminating)
284		child3.receiver.VerifyStatus(t, serviceTerminating)
285		parent.receiver.VerifyStatus(t, serviceTerminating)
286
287		// parent service is terminated once all children have terminated.
288		child1.UpdateStatus(serviceTerminated, nil)
289		child2.UpdateStatus(serviceTerminated, nil)
290		parent.receiver.VerifyNoStatusChanges(t)
291		child3.UpdateStatus(serviceTerminated, nil)
292
293		child1.receiver.VerifyStatus(t, serviceTerminated)
294		child2.receiver.VerifyStatus(t, serviceTerminated)
295		child3.receiver.VerifyStatus(t, serviceTerminated)
296		parent.receiver.VerifyStatus(t, serviceTerminated)
297		if err := parent.WaitStopped(); err != nil {
298			t.Errorf("compositeService.WaitStopped() got err: %v", err)
299		}
300	})
301}
302
303func TestCompositeServiceErrorDuringStartup(t *testing.T) {
304	child1 := newTestService("child1")
305	child2 := newTestService("child2")
306	parent := newTestCompositeService("parent")
307	if err := parent.AddServices(child1, child2); err != nil {
308		t.Errorf("AddServices() got err: %v", err)
309	}
310
311	t.Run("Starting", func(t *testing.T) {
312		parent.Start()
313
314		parent.receiver.VerifyStatus(t, serviceStarting)
315		child1.receiver.VerifyStatus(t, serviceStarting)
316		child2.receiver.VerifyStatus(t, serviceStarting)
317	})
318
319	t.Run("Terminating", func(t *testing.T) {
320		// child1 now errors.
321		wantErr := errors.New("err during startup")
322		child1.UpdateStatus(serviceTerminated, wantErr)
323		child1.receiver.VerifyStatus(t, serviceTerminated)
324
325		// This causes parent and child2 to start terminating.
326		parent.receiver.VerifyStatus(t, serviceTerminating)
327		child2.receiver.VerifyStatus(t, serviceTerminating)
328
329		// parent has terminated once child2 has terminated.
330		child2.UpdateStatus(serviceTerminated, nil)
331		child2.receiver.VerifyStatus(t, serviceTerminated)
332		parent.receiver.VerifyStatus(t, serviceTerminated)
333		if gotErr := parent.WaitStarted(); !test.ErrorEqual(gotErr, wantErr) {
334			t.Errorf("compositeService.WaitStarted() got err: (%v), want err: (%v)", gotErr, wantErr)
335		}
336	})
337}
338
339func TestCompositeServiceErrorWhileActive(t *testing.T) {
340	child1 := newTestService("child1")
341	child2 := newTestService("child2")
342	parent := newTestCompositeService("parent")
343	if err := parent.AddServices(child1, child2); err != nil {
344		t.Errorf("AddServices() got err: %v", err)
345	}
346
347	t.Run("Starting", func(t *testing.T) {
348		parent.Start()
349
350		child1.receiver.VerifyStatus(t, serviceStarting)
351		child2.receiver.VerifyStatus(t, serviceStarting)
352		parent.receiver.VerifyStatus(t, serviceStarting)
353	})
354
355	t.Run("Active", func(t *testing.T) {
356		child1.UpdateStatus(serviceActive, nil)
357		child2.UpdateStatus(serviceActive, nil)
358
359		child1.receiver.VerifyStatus(t, serviceActive)
360		child2.receiver.VerifyStatus(t, serviceActive)
361		parent.receiver.VerifyStatus(t, serviceActive)
362		if err := parent.WaitStarted(); err != nil {
363			t.Errorf("compositeService.WaitStarted() got err: %v", err)
364		}
365	})
366
367	t.Run("Terminating", func(t *testing.T) {
368		// child2 now errors.
369		wantErr := errors.New("err while active")
370		child2.UpdateStatus(serviceTerminating, wantErr)
371		child2.receiver.VerifyStatus(t, serviceTerminating)
372
373		// This causes parent and child1 to start terminating.
374		child1.receiver.VerifyStatus(t, serviceTerminating)
375		parent.receiver.VerifyStatus(t, serviceTerminating)
376
377		// parent has terminated once both children have terminated.
378		child1.UpdateStatus(serviceTerminated, nil)
379		child2.UpdateStatus(serviceTerminated, nil)
380		child1.receiver.VerifyStatus(t, serviceTerminated)
381		child2.receiver.VerifyStatus(t, serviceTerminated)
382		parent.receiver.VerifyStatus(t, serviceTerminated)
383		if gotErr := parent.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) {
384			t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr)
385		}
386	})
387}
388
389func TestCompositeServiceRemoveService(t *testing.T) {
390	child1 := newTestService("child1")
391	child2 := newTestService("child2")
392	parent := newTestCompositeService("parent")
393	if err := parent.AddServices(child1, child2); err != nil {
394		t.Errorf("AddServices() got err: %v", err)
395	}
396
397	t.Run("Starting", func(t *testing.T) {
398		parent.Start()
399
400		child1.receiver.VerifyStatus(t, serviceStarting)
401		child2.receiver.VerifyStatus(t, serviceStarting)
402		parent.receiver.VerifyStatus(t, serviceStarting)
403	})
404
405	t.Run("Active", func(t *testing.T) {
406		child1.UpdateStatus(serviceActive, nil)
407		child2.UpdateStatus(serviceActive, nil)
408
409		child1.receiver.VerifyStatus(t, serviceActive)
410		child2.receiver.VerifyStatus(t, serviceActive)
411		parent.receiver.VerifyStatus(t, serviceActive)
412	})
413
414	t.Run("Remove service", func(t *testing.T) {
415		// Removing child1 should stop it, but leave everything else active.
416		parent.RemoveService(child1)
417
418		if got, want := parent.DependenciesLen(), 1; got != want {
419			t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
420		}
421		if got, want := parent.RemovedLen(), 1; got != want {
422			t.Errorf("compositeService.removed: got len %d, want %d", got, want)
423		}
424
425		child1.receiver.VerifyStatus(t, serviceTerminating)
426		child2.receiver.VerifyNoStatusChanges(t)
427		parent.receiver.VerifyNoStatusChanges(t)
428
429		// After child1 has terminated, it should be removed.
430		child1.UpdateStatus(serviceTerminated, nil)
431
432		child1.receiver.VerifyStatus(t, serviceTerminated)
433		child2.receiver.VerifyNoStatusChanges(t)
434		parent.receiver.VerifyNoStatusChanges(t)
435	})
436
437	t.Run("Terminating", func(t *testing.T) {
438		// Now stop the composite service.
439		parent.Stop()
440
441		child2.receiver.VerifyStatus(t, serviceTerminating)
442		parent.receiver.VerifyStatus(t, serviceTerminating)
443
444		child2.UpdateStatus(serviceTerminated, nil)
445
446		child2.receiver.VerifyStatus(t, serviceTerminated)
447		parent.receiver.VerifyStatus(t, serviceTerminated)
448		if err := parent.WaitStopped(); err != nil {
449			t.Errorf("compositeService.WaitStopped() got err: %v", err)
450		}
451
452		if got, want := parent.DependenciesLen(), 1; got != want {
453			t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
454		}
455		if got, want := parent.RemovedLen(), 0; got != want {
456			t.Errorf("compositeService.removed: got len %d, want %d", got, want)
457		}
458	})
459}
460
461func TestCompositeServiceTree(t *testing.T) {
462	leaf1 := newTestService("leaf1")
463	leaf2 := newTestService("leaf2")
464	intermediate1 := newTestCompositeService("intermediate1")
465	if err := intermediate1.AddServices(leaf1, leaf2); err != nil {
466		t.Errorf("intermediate1.AddServices() got err: %v", err)
467	}
468
469	leaf3 := newTestService("leaf3")
470	leaf4 := newTestService("leaf4")
471	intermediate2 := newTestCompositeService("intermediate2")
472	if err := intermediate2.AddServices(leaf3, leaf4); err != nil {
473		t.Errorf("intermediate2.AddServices() got err: %v", err)
474	}
475
476	root := newTestCompositeService("root")
477	if err := root.AddServices(intermediate1, intermediate2); err != nil {
478		t.Errorf("root.AddServices() got err: %v", err)
479	}
480	wantErr := errors.New("fail")
481
482	t.Run("Starting", func(t *testing.T) {
483		// Start trickles down the tree.
484		root.Start()
485
486		leaf1.receiver.VerifyStatus(t, serviceStarting)
487		leaf2.receiver.VerifyStatus(t, serviceStarting)
488		leaf3.receiver.VerifyStatus(t, serviceStarting)
489		leaf4.receiver.VerifyStatus(t, serviceStarting)
490		intermediate1.receiver.VerifyStatus(t, serviceStarting)
491		intermediate2.receiver.VerifyStatus(t, serviceStarting)
492		root.receiver.VerifyStatus(t, serviceStarting)
493	})
494
495	t.Run("Active", func(t *testing.T) {
496		// serviceActive notification trickles up the tree.
497		leaf1.UpdateStatus(serviceActive, nil)
498		leaf2.UpdateStatus(serviceActive, nil)
499		leaf3.UpdateStatus(serviceActive, nil)
500		leaf4.UpdateStatus(serviceActive, nil)
501
502		leaf1.receiver.VerifyStatus(t, serviceActive)
503		leaf2.receiver.VerifyStatus(t, serviceActive)
504		leaf3.receiver.VerifyStatus(t, serviceActive)
505		leaf4.receiver.VerifyStatus(t, serviceActive)
506		intermediate1.receiver.VerifyStatus(t, serviceActive)
507		intermediate2.receiver.VerifyStatus(t, serviceActive)
508		root.receiver.VerifyStatus(t, serviceActive)
509		if err := root.WaitStarted(); err != nil {
510			t.Errorf("compositeService.WaitStarted() got err: %v", err)
511		}
512	})
513
514	t.Run("Leaf fails", func(t *testing.T) {
515		leaf1.UpdateStatus(serviceTerminated, wantErr)
516		leaf1.receiver.VerifyStatus(t, serviceTerminated)
517
518		// Leaf service failure should trickle up the tree and across to all other
519		// leaves, causing them all to start terminating.
520		leaf2.receiver.VerifyStatus(t, serviceTerminating)
521		leaf3.receiver.VerifyStatus(t, serviceTerminating)
522		leaf4.receiver.VerifyStatus(t, serviceTerminating)
523		intermediate1.receiver.VerifyStatus(t, serviceTerminating)
524		intermediate2.receiver.VerifyStatus(t, serviceTerminating)
525		root.receiver.VerifyStatus(t, serviceTerminating)
526	})
527
528	t.Run("Terminated", func(t *testing.T) {
529		// serviceTerminated notification trickles up the tree.
530		leaf2.UpdateStatus(serviceTerminated, nil)
531		leaf3.UpdateStatus(serviceTerminated, nil)
532		leaf4.UpdateStatus(serviceTerminated, nil)
533
534		leaf2.receiver.VerifyStatus(t, serviceTerminated)
535		leaf3.receiver.VerifyStatus(t, serviceTerminated)
536		leaf4.receiver.VerifyStatus(t, serviceTerminated)
537		intermediate1.receiver.VerifyStatus(t, serviceTerminated)
538		intermediate2.receiver.VerifyStatus(t, serviceTerminated)
539		root.receiver.VerifyStatus(t, serviceTerminated)
540
541		if gotErr := root.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) {
542			t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr)
543		}
544	})
545}
546
547func TestCompositeServiceAddServicesErrors(t *testing.T) {
548	child1 := newTestService("child1")
549	parent := newTestCompositeService("parent")
550	if err := parent.AddServices(child1); err != nil {
551		t.Errorf("AddServices(child1) got err: %v", err)
552	}
553
554	child2 := newTestService("child2")
555	child2.Start()
556	if gotErr, wantErr := parent.AddServices(child2), errChildServiceStarted; !test.ErrorEqual(gotErr, wantErr) {
557		t.Errorf("AddServices(child2) got err: (%v), want err: (%v)", gotErr, wantErr)
558	}
559
560	parent.Stop()
561	child3 := newTestService("child3")
562	if gotErr, wantErr := parent.AddServices(child3), ErrServiceStopped; !test.ErrorEqual(gotErr, wantErr) {
563		t.Errorf("AddServices(child3) got err: (%v), want err: (%v)", gotErr, wantErr)
564	}
565}
566