1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"errors"
18	"fmt"
19	"testing"
20	"time"
21
22	"cloud.google.com/go/pubsublite/internal/test"
23)
24
// receiveStatusTimeout bounds how long VerifyStatus waits for the next status
// notification before reporting a failure.
const receiveStatusTimeout time.Duration = 5 * time.Second
26
27type testStatusChangeReceiver struct {
28	// Status change notifications are fired asynchronously, so a channel receives
29	// the statuses.
30	statusC    chan serviceStatus
31	lastStatus serviceStatus
32	name       string
33}
34
35func newTestStatusChangeReceiver(name string) *testStatusChangeReceiver {
36	return &testStatusChangeReceiver{
37		statusC: make(chan serviceStatus, 1),
38		name:    name,
39	}
40}
41
42func (sr *testStatusChangeReceiver) Handle() interface{} { return sr }
43
44func (sr *testStatusChangeReceiver) OnStatusChange(handle serviceHandle, status serviceStatus, err error) {
45	sr.statusC <- status
46}
47
48func (sr *testStatusChangeReceiver) VerifyStatus(t *testing.T, want serviceStatus) {
49	select {
50	case status := <-sr.statusC:
51		if status <= sr.lastStatus {
52			t.Errorf("%s: Duplicate service status: %d, last status: %d", sr.name, status, sr.lastStatus)
53		}
54		if status != want {
55			t.Errorf("%s: Got service status: %d, want: %d", sr.name, status, want)
56		}
57		sr.lastStatus = status
58	case <-time.After(receiveStatusTimeout):
59		t.Errorf("%s: Did not receive status within %v", sr.name, receiveStatusTimeout)
60	}
61}
62
63func (sr *testStatusChangeReceiver) VerifyNoStatusChanges(t *testing.T) {
64	select {
65	case status := <-sr.statusC:
66		t.Errorf("%s: Unexpected service status: %d", sr.name, status)
67	default:
68	}
69}
70
71type testService struct {
72	receiver *testStatusChangeReceiver
73	abstractService
74}
75
76func newTestService(name string) *testService {
77	receiver := newTestStatusChangeReceiver(name)
78	ts := &testService{receiver: receiver}
79	ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange)
80	return ts
81}
82
83func (ts *testService) Start() { ts.UpdateStatus(serviceStarting, nil) }
84func (ts *testService) Stop()  { ts.UpdateStatus(serviceTerminating, nil) }
85
86func (ts *testService) UpdateStatus(targetStatus serviceStatus, err error) {
87	ts.mu.Lock()
88	defer ts.mu.Unlock()
89	ts.unsafeUpdateStatus(targetStatus, err)
90}
91
92func TestServiceUpdateStatusIsLinear(t *testing.T) {
93	err1 := errors.New("error1")
94	err2 := errors.New("error2")
95
96	service := newTestService("service")
97	service.UpdateStatus(serviceStarting, nil)
98	service.receiver.VerifyStatus(t, serviceStarting)
99
100	service.UpdateStatus(serviceActive, nil)
101	service.UpdateStatus(serviceActive, nil)
102	service.receiver.VerifyStatus(t, serviceActive)
103
104	service.UpdateStatus(serviceTerminating, err1)
105	service.UpdateStatus(serviceStarting, nil)
106	service.UpdateStatus(serviceTerminating, nil)
107	service.receiver.VerifyStatus(t, serviceTerminating)
108
109	service.UpdateStatus(serviceTerminated, err2)
110	service.UpdateStatus(serviceTerminated, nil)
111	service.receiver.VerifyStatus(t, serviceTerminated)
112
113	// Verify that the first error is not clobbered by the second.
114	if got, want := service.Error(), err1; !test.ErrorEqual(got, want) {
115		t.Errorf("service.Error(): got (%v), want (%v)", got, want)
116	}
117}
118
119func TestServiceCheckServiceStatus(t *testing.T) {
120	for _, tc := range []struct {
121		status  serviceStatus
122		wantErr error
123	}{
124		{
125			status:  serviceUninitialized,
126			wantErr: ErrServiceUninitialized,
127		},
128		{
129			status:  serviceStarting,
130			wantErr: ErrServiceStarting,
131		},
132		{
133			status:  serviceActive,
134			wantErr: nil,
135		},
136		{
137			status:  serviceTerminating,
138			wantErr: ErrServiceStopped,
139		},
140		{
141			status:  serviceTerminated,
142			wantErr: ErrServiceStopped,
143		},
144	} {
145		t.Run(fmt.Sprintf("Status=%v", tc.status), func(t *testing.T) {
146			s := newTestService("service")
147			s.UpdateStatus(tc.status, nil)
148			if gotErr := s.unsafeCheckServiceStatus(); !test.ErrorEqual(gotErr, tc.wantErr) {
149				t.Errorf("service.unsafeCheckServiceStatus(): got (%v), want (%v)", gotErr, tc.wantErr)
150			}
151		})
152	}
153}
154
155func TestServiceAddRemoveStatusChangeReceiver(t *testing.T) {
156	receiver1 := newTestStatusChangeReceiver("receiver1")
157	receiver2 := newTestStatusChangeReceiver("receiver2")
158	receiver3 := newTestStatusChangeReceiver("receiver3")
159
160	service := new(testService)
161	service.AddStatusChangeReceiver(receiver1.Handle(), receiver1.OnStatusChange)
162	service.AddStatusChangeReceiver(receiver2.Handle(), receiver2.OnStatusChange)
163	service.AddStatusChangeReceiver(receiver3.Handle(), receiver3.OnStatusChange)
164
165	t.Run("All receivers", func(t *testing.T) {
166		service.UpdateStatus(serviceActive, nil)
167
168		receiver1.VerifyStatus(t, serviceActive)
169		receiver2.VerifyStatus(t, serviceActive)
170		receiver3.VerifyStatus(t, serviceActive)
171	})
172
173	t.Run("receiver1 removed", func(t *testing.T) {
174		service.RemoveStatusChangeReceiver(receiver1.Handle())
175		service.UpdateStatus(serviceTerminating, nil)
176
177		receiver1.VerifyNoStatusChanges(t)
178		receiver2.VerifyStatus(t, serviceTerminating)
179		receiver3.VerifyStatus(t, serviceTerminating)
180	})
181
182	t.Run("receiver2 removed", func(t *testing.T) {
183		service.RemoveStatusChangeReceiver(receiver2.Handle())
184		service.UpdateStatus(serviceTerminated, nil)
185
186		receiver1.VerifyNoStatusChanges(t)
187		receiver2.VerifyNoStatusChanges(t)
188		receiver3.VerifyStatus(t, serviceTerminated)
189	})
190}
191
192type testCompositeService struct {
193	receiver *testStatusChangeReceiver
194	compositeService
195}
196
197func newTestCompositeService(name string) *testCompositeService {
198	receiver := newTestStatusChangeReceiver(name)
199	ts := &testCompositeService{receiver: receiver}
200	ts.AddStatusChangeReceiver(receiver.Handle(), receiver.OnStatusChange)
201	ts.init()
202	return ts
203}
204
205func (ts *testCompositeService) AddServices(services ...service) error {
206	ts.mu.Lock()
207	defer ts.mu.Unlock()
208	return ts.unsafeAddServices(services...)
209}
210
211func (ts *testCompositeService) RemoveService(service service) {
212	ts.mu.Lock()
213	defer ts.mu.Unlock()
214	ts.unsafeRemoveService(service)
215}
216
217func (ts *testCompositeService) DependenciesLen() int {
218	ts.mu.Lock()
219	defer ts.mu.Unlock()
220	return len(ts.dependencies)
221}
222
223func (ts *testCompositeService) RemovedLen() int {
224	ts.mu.Lock()
225	defer ts.mu.Unlock()
226	return len(ts.removed)
227}
228
229func TestCompositeServiceNormalStop(t *testing.T) {
230	child1 := newTestService("child1")
231	child2 := newTestService("child2")
232	child3 := newTestService("child3")
233	parent := newTestCompositeService("parent")
234	if err := parent.AddServices(child1, child2); err != nil {
235		t.Errorf("AddServices() got err: %v", err)
236	}
237
238	t.Run("Starting", func(t *testing.T) {
239		wantState := serviceUninitialized
240		if child1.Status() != wantState {
241			t.Errorf("child1: current service status: got %d, want %d", child1.Status(), wantState)
242		}
243		if child2.Status() != wantState {
244			t.Errorf("child2: current service status: got %d, want %d", child2.Status(), wantState)
245		}
246
247		parent.Start()
248
249		child1.receiver.VerifyStatus(t, serviceStarting)
250		child2.receiver.VerifyStatus(t, serviceStarting)
251		parent.receiver.VerifyStatus(t, serviceStarting)
252
253		// child3 is added after Start() and should be automatically started.
254		if child3.Status() != wantState {
255			t.Errorf("child3: current service status: got %d, want %d", child3.Status(), wantState)
256		}
257		if err := parent.AddServices(child3); err != nil {
258			t.Errorf("AddServices() got err: %v", err)
259		}
260		child3.receiver.VerifyStatus(t, serviceStarting)
261	})
262
263	t.Run("Active", func(t *testing.T) {
264		// parent service is active once all children are active.
265		child1.UpdateStatus(serviceActive, nil)
266		child2.UpdateStatus(serviceActive, nil)
267		parent.receiver.VerifyNoStatusChanges(t)
268		child3.UpdateStatus(serviceActive, nil)
269
270		child1.receiver.VerifyStatus(t, serviceActive)
271		child2.receiver.VerifyStatus(t, serviceActive)
272		child3.receiver.VerifyStatus(t, serviceActive)
273		parent.receiver.VerifyStatus(t, serviceActive)
274		if err := parent.WaitStarted(); err != nil {
275			t.Errorf("compositeService.WaitStarted() got err: %v", err)
276		}
277	})
278
279	t.Run("Stopping", func(t *testing.T) {
280		parent.Stop()
281
282		child1.receiver.VerifyStatus(t, serviceTerminating)
283		child2.receiver.VerifyStatus(t, serviceTerminating)
284		child3.receiver.VerifyStatus(t, serviceTerminating)
285		parent.receiver.VerifyStatus(t, serviceTerminating)
286
287		// parent service is terminated once all children have terminated.
288		child1.UpdateStatus(serviceTerminated, nil)
289		child2.UpdateStatus(serviceTerminated, nil)
290		parent.receiver.VerifyNoStatusChanges(t)
291		child3.UpdateStatus(serviceTerminated, nil)
292
293		child1.receiver.VerifyStatus(t, serviceTerminated)
294		child2.receiver.VerifyStatus(t, serviceTerminated)
295		child3.receiver.VerifyStatus(t, serviceTerminated)
296		parent.receiver.VerifyStatus(t, serviceTerminated)
297		if err := parent.WaitStopped(); err != nil {
298			t.Errorf("compositeService.WaitStopped() got err: %v", err)
299		}
300	})
301}
302
303func TestCompositeServiceErrorDuringStartup(t *testing.T) {
304	child1 := newTestService("child1")
305	child2 := newTestService("child2")
306	parent := newTestCompositeService("parent")
307	if err := parent.AddServices(child1, child2); err != nil {
308		t.Errorf("AddServices() got err: %v", err)
309	}
310
311	t.Run("Starting", func(t *testing.T) {
312		parent.Start()
313
314		parent.receiver.VerifyStatus(t, serviceStarting)
315		child1.receiver.VerifyStatus(t, serviceStarting)
316		child2.receiver.VerifyStatus(t, serviceStarting)
317	})
318
319	t.Run("Terminating", func(t *testing.T) {
320		// child1 now errors.
321		wantErr := errors.New("err during startup")
322		child1.UpdateStatus(serviceTerminated, wantErr)
323		child1.receiver.VerifyStatus(t, serviceTerminated)
324
325		// This causes parent and child2 to start terminating.
326		parent.receiver.VerifyStatus(t, serviceTerminating)
327		child2.receiver.VerifyStatus(t, serviceTerminating)
328
329		// parent has terminated once child2 has terminated.
330		child2.UpdateStatus(serviceTerminated, nil)
331		child2.receiver.VerifyStatus(t, serviceTerminated)
332		parent.receiver.VerifyStatus(t, serviceTerminated)
333		if gotErr := parent.WaitStarted(); !test.ErrorEqual(gotErr, wantErr) {
334			t.Errorf("compositeService.WaitStarted() got err: (%v), want err: (%v)", gotErr, wantErr)
335		}
336	})
337}
338
339func TestCompositeServiceErrorWhileActive(t *testing.T) {
340	child1 := newTestService("child1")
341	child2 := newTestService("child2")
342	parent := newTestCompositeService("parent")
343	if err := parent.AddServices(child1, child2); err != nil {
344		t.Errorf("AddServices() got err: %v", err)
345	}
346
347	t.Run("Starting", func(t *testing.T) {
348		parent.Start()
349
350		child1.receiver.VerifyStatus(t, serviceStarting)
351		child2.receiver.VerifyStatus(t, serviceStarting)
352		parent.receiver.VerifyStatus(t, serviceStarting)
353	})
354
355	t.Run("Active", func(t *testing.T) {
356		child1.UpdateStatus(serviceActive, nil)
357		child2.UpdateStatus(serviceActive, nil)
358
359		child1.receiver.VerifyStatus(t, serviceActive)
360		child2.receiver.VerifyStatus(t, serviceActive)
361		parent.receiver.VerifyStatus(t, serviceActive)
362		if err := parent.WaitStarted(); err != nil {
363			t.Errorf("compositeService.WaitStarted() got err: %v", err)
364		}
365	})
366
367	t.Run("Terminating", func(t *testing.T) {
368		// child2 now errors.
369		wantErr := errors.New("err while active")
370		child2.UpdateStatus(serviceTerminating, wantErr)
371		child2.receiver.VerifyStatus(t, serviceTerminating)
372
373		// This causes parent and child1 to start terminating.
374		child1.receiver.VerifyStatus(t, serviceTerminating)
375		parent.receiver.VerifyStatus(t, serviceTerminating)
376
377		// parent has terminated once both children have terminated.
378		child1.UpdateStatus(serviceTerminated, nil)
379		child2.UpdateStatus(serviceTerminated, nil)
380		child1.receiver.VerifyStatus(t, serviceTerminated)
381		child2.receiver.VerifyStatus(t, serviceTerminated)
382		parent.receiver.VerifyStatus(t, serviceTerminated)
383		if gotErr := parent.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) {
384			t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr)
385		}
386	})
387}
388
389func TestCompositeServiceRemoveService(t *testing.T) {
390	child1 := newTestService("child1")
391	child2 := newTestService("child2")
392	parent := newTestCompositeService("parent")
393	if err := parent.AddServices(child1, child2); err != nil {
394		t.Errorf("AddServices() got err: %v", err)
395	}
396
397	t.Run("Starting", func(t *testing.T) {
398		parent.Start()
399
400		child1.receiver.VerifyStatus(t, serviceStarting)
401		child2.receiver.VerifyStatus(t, serviceStarting)
402		parent.receiver.VerifyStatus(t, serviceStarting)
403	})
404
405	t.Run("Active", func(t *testing.T) {
406		child1.UpdateStatus(serviceActive, nil)
407		child2.UpdateStatus(serviceActive, nil)
408
409		child1.receiver.VerifyStatus(t, serviceActive)
410		child2.receiver.VerifyStatus(t, serviceActive)
411		parent.receiver.VerifyStatus(t, serviceActive)
412	})
413
414	t.Run("Remove service", func(t *testing.T) {
415		if got, want := parent.DependenciesLen(), 2; got != want {
416			t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
417		}
418		if got, want := parent.RemovedLen(), 0; got != want {
419			t.Errorf("compositeService.removed: got len %d, want %d", got, want)
420		}
421
422		// Removing child1 should stop it, but leave everything else active.
423		parent.RemoveService(child1)
424
425		if got, want := parent.DependenciesLen(), 1; got != want {
426			t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
427		}
428		if got, want := parent.RemovedLen(), 1; got != want {
429			t.Errorf("compositeService.removed: got len %d, want %d", got, want)
430		}
431
432		child1.receiver.VerifyStatus(t, serviceTerminating)
433		child2.receiver.VerifyNoStatusChanges(t)
434		parent.receiver.VerifyNoStatusChanges(t)
435
436		// After child1 has terminated, it should be removed.
437		child1.UpdateStatus(serviceTerminated, nil)
438
439		child1.receiver.VerifyStatus(t, serviceTerminated)
440		child2.receiver.VerifyNoStatusChanges(t)
441		parent.receiver.VerifyNoStatusChanges(t)
442		if got, want := parent.Status(), serviceActive; got != want {
443			t.Errorf("compositeService.Status() got %v, want %v", got, want)
444		}
445	})
446
447	t.Run("Terminating", func(t *testing.T) {
448		// Now stop the composite service.
449		parent.Stop()
450
451		child2.receiver.VerifyStatus(t, serviceTerminating)
452		parent.receiver.VerifyStatus(t, serviceTerminating)
453
454		child2.UpdateStatus(serviceTerminated, nil)
455
456		child2.receiver.VerifyStatus(t, serviceTerminated)
457		parent.receiver.VerifyStatus(t, serviceTerminated)
458		if err := parent.WaitStopped(); err != nil {
459			t.Errorf("compositeService.WaitStopped() got err: %v", err)
460		}
461
462		if got, want := parent.DependenciesLen(), 1; got != want {
463			t.Errorf("compositeService.dependencies: got len %d, want %d", got, want)
464		}
465		if got, want := parent.RemovedLen(), 0; got != want {
466			t.Errorf("compositeService.removed: got len %d, want %d", got, want)
467		}
468	})
469}
470
471func TestCompositeServiceTree(t *testing.T) {
472	leaf1 := newTestService("leaf1")
473	leaf2 := newTestService("leaf2")
474	intermediate1 := newTestCompositeService("intermediate1")
475	if err := intermediate1.AddServices(leaf1, leaf2); err != nil {
476		t.Errorf("intermediate1.AddServices() got err: %v", err)
477	}
478
479	leaf3 := newTestService("leaf3")
480	leaf4 := newTestService("leaf4")
481	intermediate2 := newTestCompositeService("intermediate2")
482	if err := intermediate2.AddServices(leaf3, leaf4); err != nil {
483		t.Errorf("intermediate2.AddServices() got err: %v", err)
484	}
485
486	root := newTestCompositeService("root")
487	if err := root.AddServices(intermediate1, intermediate2); err != nil {
488		t.Errorf("root.AddServices() got err: %v", err)
489	}
490	wantErr := errors.New("fail")
491
492	t.Run("Starting", func(t *testing.T) {
493		// Start trickles down the tree.
494		root.Start()
495
496		leaf1.receiver.VerifyStatus(t, serviceStarting)
497		leaf2.receiver.VerifyStatus(t, serviceStarting)
498		leaf3.receiver.VerifyStatus(t, serviceStarting)
499		leaf4.receiver.VerifyStatus(t, serviceStarting)
500		intermediate1.receiver.VerifyStatus(t, serviceStarting)
501		intermediate2.receiver.VerifyStatus(t, serviceStarting)
502		root.receiver.VerifyStatus(t, serviceStarting)
503	})
504
505	t.Run("Active", func(t *testing.T) {
506		// serviceActive notification trickles up the tree.
507		leaf1.UpdateStatus(serviceActive, nil)
508		leaf2.UpdateStatus(serviceActive, nil)
509		leaf3.UpdateStatus(serviceActive, nil)
510		leaf4.UpdateStatus(serviceActive, nil)
511
512		leaf1.receiver.VerifyStatus(t, serviceActive)
513		leaf2.receiver.VerifyStatus(t, serviceActive)
514		leaf3.receiver.VerifyStatus(t, serviceActive)
515		leaf4.receiver.VerifyStatus(t, serviceActive)
516		intermediate1.receiver.VerifyStatus(t, serviceActive)
517		intermediate2.receiver.VerifyStatus(t, serviceActive)
518		root.receiver.VerifyStatus(t, serviceActive)
519		if err := root.WaitStarted(); err != nil {
520			t.Errorf("compositeService.WaitStarted() got err: %v", err)
521		}
522	})
523
524	t.Run("Leaf fails", func(t *testing.T) {
525		leaf1.UpdateStatus(serviceTerminated, wantErr)
526		leaf1.receiver.VerifyStatus(t, serviceTerminated)
527
528		// Leaf service failure should trickle up the tree and across to all other
529		// leaves, causing them all to start terminating.
530		leaf2.receiver.VerifyStatus(t, serviceTerminating)
531		leaf3.receiver.VerifyStatus(t, serviceTerminating)
532		leaf4.receiver.VerifyStatus(t, serviceTerminating)
533		intermediate1.receiver.VerifyStatus(t, serviceTerminating)
534		intermediate2.receiver.VerifyStatus(t, serviceTerminating)
535		root.receiver.VerifyStatus(t, serviceTerminating)
536	})
537
538	t.Run("Terminated", func(t *testing.T) {
539		// serviceTerminated notification trickles up the tree.
540		leaf2.UpdateStatus(serviceTerminated, nil)
541		leaf3.UpdateStatus(serviceTerminated, nil)
542		leaf4.UpdateStatus(serviceTerminated, nil)
543
544		leaf2.receiver.VerifyStatus(t, serviceTerminated)
545		leaf3.receiver.VerifyStatus(t, serviceTerminated)
546		leaf4.receiver.VerifyStatus(t, serviceTerminated)
547		intermediate1.receiver.VerifyStatus(t, serviceTerminated)
548		intermediate2.receiver.VerifyStatus(t, serviceTerminated)
549		root.receiver.VerifyStatus(t, serviceTerminated)
550
551		if gotErr := root.WaitStopped(); !test.ErrorEqual(gotErr, wantErr) {
552			t.Errorf("compositeService.WaitStopped() got err: (%v), want err: (%v)", gotErr, wantErr)
553		}
554	})
555}
556
557func TestCompositeServiceAddServicesErrors(t *testing.T) {
558	child1 := newTestService("child1")
559	parent := newTestCompositeService("parent")
560	if err := parent.AddServices(child1); err != nil {
561		t.Errorf("AddServices(child1) got err: %v", err)
562	}
563
564	child2 := newTestService("child2")
565	child2.Start()
566	if gotErr, wantErr := parent.AddServices(child2), errChildServiceStarted; !test.ErrorEqual(gotErr, wantErr) {
567		t.Errorf("AddServices(child2) got err: (%v), want err: (%v)", gotErr, wantErr)
568	}
569
570	parent.Stop()
571	child3 := newTestService("child3")
572	if gotErr, wantErr := parent.AddServices(child3), ErrServiceStopped; !test.ErrorEqual(gotErr, wantErr) {
573		t.Errorf("AddServices(child3) got err: (%v), want err: (%v)", gotErr, wantErr)
574	}
575}
576