package centrifuge

import (
	"context"
	"fmt"
	"io"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/centrifugal/centrifuge/internal/prepared"

	"github.com/centrifugal/protocol"
	"github.com/stretchr/testify/require"
)

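// testTransport is an in-memory Transport implementation used throughout
// these tests. Written messages are forwarded to the optional sink channel,
// Close records the Disconnect reason and signals closeCh, and writeErr
// allows simulating transport write failures.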
type testTransport struct {
	mu             sync.Mutex
	sink           chan []byte
	closed         bool
	closeCh        chan struct{}
	disconnect     *Disconnect
	protoType      ProtocolType
	cancelFn       func()
	unidirectional bool
	writeErr       error
}

func newTestTransport(cancelFn func()) *testTransport {
	return &testTransport{
		cancelFn:       cancelFn,
		protoType:      ProtocolTypeJSON,
		closeCh:        make(chan struct{}),
		unidirectional: false,
	}
}

func (t *testTransport) setProtocolType(pType ProtocolType) {
	t.protoType = pType
}

func (t *testTransport) setUnidirectional(uni bool) {
	t.unidirectional = uni
}

func (t *testTransport) setSink(sink chan []byte) {
	t.sink = sink
}

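// Write returns writeErr when set, io.EOF after Close, and otherwise
// forwards the message to the sink channel (if one is configured).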
func (t *testTransport) Write(message []byte) error {
	if t.writeErr != nil {
		return t.writeErr
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.closed {
		return io.EOF
	}
	if t.sink != nil {
		t.sink <- message
	}
	return nil
}

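// WriteMany has the same semantics as Write, applied to each message in order.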
func (t *testTransport) WriteMany(messages ...[]byte) error {
	if t.writeErr != nil {
		return t.writeErr
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.closed {
		return io.EOF
	}
	for _, buf := range messages {
		if t.sink != nil {
			t.sink <- buf
		}
	}
	return nil
}

func (t *testTransport) Name() string {
	return transportWebsocket
}

func (t *testTransport) Protocol() ProtocolType {
	return t.protoType
}

func (t *testTransport) Unidirectional() bool {
	return t.unidirectional
}

func (t *testTransport) DisabledPushFlags() uint64 {
	if t.Unidirectional() {
		return 0
	}
	return PushFlagDisconnect
}

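// Close is idempotent: the first call records the disconnect reason, invokes
// cancelFn and closes closeCh; subsequent calls are no-ops.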
func (t *testTransport) Close(disconnect *Disconnect) error {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.closed {
		return nil
	}
	t.disconnect = disconnect
	t.closed = true
	t.cancelFn()
	close(t.closeCh)
	return nil
}

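// A minimal sketch (not part of the original suite) exercising the
// testTransport plumbing itself: writes reach the sink, Close records the
// disconnect reason and fires cancelFn, and a write after Close fails with
// io.EOF.
func TestTestTransportSketch(t *testing.T) {
	canceled := false
	tr := newTestTransport(func() { canceled = true })
	sink := make(chan []byte, 1)
	tr.setSink(sink)

	require.NoError(t, tr.Write([]byte("hello")))
	require.Equal(t, []byte("hello"), <-sink)

	require.NoError(t, tr.Close(DisconnectForceNoReconnect))
	require.True(t, canceled)
	require.Equal(t, DisconnectForceNoReconnect, tr.disconnect)
	require.Equal(t, io.EOF, tr.Write([]byte("after close")))
}
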
func TestHub(t *testing.T) {
	h := newHub()
	c, err := newClient(context.Background(), defaultTestNode(), newTestTransport(func() {}))
	require.NoError(t, err)
	c.user = "test"
	err = h.remove(c)
	require.NoError(t, err)
	err = h.add(c)
	require.NoError(t, err)
	conns := h.UserConnections("test")
	require.Equal(t, 1, len(conns))
	require.Equal(t, 1, h.NumClients())
	require.Equal(t, 1, h.NumUsers())

	validUID := c.uid
	c.uid = "invalid"
	err = h.remove(c)
	require.NoError(t, err)
	require.Len(t, h.UserConnections("test"), 1)

	c.uid = validUID
	err = h.remove(c)
	require.NoError(t, err)
	require.Len(t, h.UserConnections("test"), 0)
}

func TestHubUnsubscribe(t *testing.T) {
	n := defaultTestNode()
	defer func() { _ = n.Shutdown(context.Background()) }()

	client := newTestSubscribedClient(t, n, "42", "test_channel")
	transport := client.transport.(*testTransport)
	transport.sink = make(chan []byte, 100)

	// Unsubscribe a non-existent user.
	err := n.hub.unsubscribe("1", "test_channel", "")
	require.NoError(t, err)

	// Unsubscribe a subscribed user.
	err = n.hub.unsubscribe("42", "test_channel", "")
	require.NoError(t, err)
	select {
	case data := <-transport.sink:
		require.Equal(t, "{\"result\":{\"type\":3,\"channel\":\"test_channel\",\"data\":{}}}", string(data))
	case <-time.After(2 * time.Second):
		t.Fatal("no data in sink")
	}
	require.Zero(t, n.hub.NumSubscribers("test_channel"))
}

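// TestHubDisconnect verifies that hub.disconnect closes only the target
// user's connections and that the DisconnectEvent observed by each client
// carries the expected Reconnect flag.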
func TestHubDisconnect(t *testing.T) {
	n := defaultNodeNoHandlers()
	defer func() { _ = n.Shutdown(context.Background()) }()

	n.OnConnect(func(client *Client) {
		client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{}, nil)
		})
	})

	client := newTestSubscribedClient(t, n, "42", "test_channel")
	clientWithReconnect := newTestSubscribedClient(t, n, "24", "test_channel_reconnect")
	require.Len(t, n.hub.UserConnections("42"), 1)
	require.Len(t, n.hub.UserConnections("24"), 1)
	require.Equal(t, 1, n.hub.NumSubscribers("test_channel"))
	require.Equal(t, 1, n.hub.NumSubscribers("test_channel_reconnect"))

	wg := sync.WaitGroup{}
	wg.Add(2)

	client.eventHub.disconnectHandler = func(e DisconnectEvent) {
		defer wg.Done()
		require.False(t, e.Disconnect.Reconnect)
	}

	clientWithReconnect.eventHub.disconnectHandler = func(e DisconnectEvent) {
		defer wg.Done()
		require.True(t, e.Disconnect.Reconnect)
	}

	// Disconnect a non-existent user.
	err := n.hub.disconnect("1", DisconnectForceNoReconnect, "", nil)
	require.NoError(t, err)

	// Disconnect a subscribed user.
	err = n.hub.disconnect("42", DisconnectForceNoReconnect, "", nil)
	require.NoError(t, err)
	select {
	case <-client.transport.(*testTransport).closeCh:
	case <-time.After(2 * time.Second):
		t.Fatal("transport not closed")
	}
	require.Len(t, n.hub.UserConnections("42"), 0)
	require.Equal(t, 0, n.hub.NumSubscribers("test_channel"))

	// Disconnect a subscribed user with reconnect.
	err = n.hub.disconnect("24", DisconnectForceReconnect, "", nil)
	require.NoError(t, err)
	select {
	case <-clientWithReconnect.transport.(*testTransport).closeCh:
	case <-time.After(2 * time.Second):
		t.Fatal("transport not closed")
	}
	require.Len(t, n.hub.UserConnections("24"), 0)
	require.Equal(t, 0, n.hub.NumSubscribers("test_channel_reconnect"))

	wg.Wait()

	require.Len(t, n.hub.UserConnections("24"), 0)
	require.Len(t, n.hub.UserConnections("42"), 0)
	require.Equal(t, 0, n.hub.NumSubscribers("test_channel"))
	require.Equal(t, 0, n.hub.NumSubscribers("test_channel_reconnect"))
}

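// TestHubDisconnect_ClientWhitelist verifies that a whitelisted client ID
// survives a user-wide disconnect while the user's other connections are
// closed.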
func TestHubDisconnect_ClientWhitelist(t *testing.T) {
	n := defaultNodeNoHandlers()
	defer func() { _ = n.Shutdown(context.Background()) }()

	n.OnConnect(func(client *Client) {
		client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{}, nil)
		})
	})

	client := newTestSubscribedClient(t, n, "12", "test_channel")
	clientToKeep := newTestSubscribedClient(t, n, "12", "test_channel")

	require.Len(t, n.hub.UserConnections("12"), 2)
	require.Equal(t, 2, n.hub.NumSubscribers("test_channel"))

	shouldBeClosed := make(chan struct{})
	shouldNotBeClosed := make(chan struct{})

	client.eventHub.disconnectHandler = func(e DisconnectEvent) {
		close(shouldBeClosed)
	}

	clientToKeep.eventHub.disconnectHandler = func(e DisconnectEvent) {
		close(shouldNotBeClosed)
	}

	whitelist := []string{clientToKeep.ID()}

	// Disconnect the user, keeping the whitelisted connection.
	err := n.hub.disconnect("12", DisconnectConnectionLimit, "", whitelist)
	require.NoError(t, err)

	select {
	case <-shouldBeClosed:
		select {
		case <-shouldNotBeClosed:
			require.Fail(t, "client should not be disconnected")
		case <-time.After(time.Second):
			require.Len(t, n.hub.UserConnections("12"), 1)
			require.Equal(t, 1, n.hub.NumSubscribers("test_channel"))
		}
	case <-time.After(time.Second):
		require.Fail(t, "timeout waiting for channel close")
	}
}

func TestHubDisconnect_ClientID(t *testing.T) {
	n := defaultNodeNoHandlers()
	defer func() { _ = n.Shutdown(context.Background()) }()

	n.OnConnect(func(client *Client) {
		client.OnSubscribe(func(event SubscribeEvent, cb SubscribeCallback) {
			cb(SubscribeReply{}, nil)
		})
	})

	client := newTestSubscribedClient(t, n, "12", "test_channel")
	clientToKeep := newTestSubscribedClient(t, n, "12", "test_channel")

	require.Len(t, n.hub.UserConnections("12"), 2)
	require.Equal(t, 2, n.hub.NumSubscribers("test_channel"))

	shouldBeClosed := make(chan struct{})
	shouldNotBeClosed := make(chan struct{})

	client.eventHub.disconnectHandler = func(e DisconnectEvent) {
		close(shouldBeClosed)
	}

	clientToKeep.eventHub.disconnectHandler = func(e DisconnectEvent) {
		close(shouldNotBeClosed)
	}

	clientToDisconnect := client.ID()

	// Disconnect only the connection with the given client ID.
	err := n.hub.disconnect("12", DisconnectConnectionLimit, clientToDisconnect, nil)
	require.NoError(t, err)

	select {
	case <-shouldBeClosed:
		select {
		case <-shouldNotBeClosed:
			require.Fail(t, "client should not be disconnected")
		case <-time.After(time.Second):
			require.Len(t, n.hub.UserConnections("12"), 1)
			require.Equal(t, 1, n.hub.NumSubscribers("test_channel"))
		}
	case <-time.After(time.Second):
		require.Fail(t, "timeout waiting for channel close")
	}
}

func TestHubBroadcastPublication(t *testing.T) {
	tcs := []struct {
		name         string
		protocolType ProtocolType
	}{
		{name: "JSON", protocolType: ProtocolTypeJSON},
		{name: "Protobuf", protocolType: ProtocolTypeProtobuf},
	}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			n := defaultTestNode()
			defer func() { _ = n.Shutdown(context.Background()) }()

			client := newTestSubscribedClient(t, n, "42", "test_channel")
			transport := client.transport.(*testTransport)
			transport.sink = make(chan []byte, 100)
			transport.protoType = tc.protocolType

			// Broadcast to a non-existent channel.
			err := n.hub.BroadcastPublication(
				"not_test_channel",
				&Publication{Data: []byte(`{"data": "broadcast_data"}`)},
				StreamPosition{},
			)
			require.NoError(t, err)

			// Broadcast to an existing channel.
			err = n.hub.BroadcastPublication(
				"test_channel",
				&Publication{Data: []byte(`{"data": "broadcast_data"}`)},
				StreamPosition{},
			)
			require.NoError(t, err)
			select {
			case data := <-transport.sink:
				require.Contains(t, string(data), "broadcast_data")
			case <-time.After(2 * time.Second):
				t.Fatal("no data in sink")
			}
		})
	}
}

func TestHubBroadcastJoin(t *testing.T) {
	tcs := []struct {
		name         string
		protocolType ProtocolType
	}{
		{name: "JSON", protocolType: ProtocolTypeJSON},
		{name: "Protobuf", protocolType: ProtocolTypeProtobuf},
	}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			n := defaultTestNode()
			defer func() { _ = n.Shutdown(context.Background()) }()

			client := newTestSubscribedClient(t, n, "42", "test_channel")
			transport := client.transport.(*testTransport)
			transport.sink = make(chan []byte, 100)
			transport.protoType = tc.protocolType

			// Broadcast to a non-existent channel.
			err := n.hub.broadcastJoin("not_test_channel", &ClientInfo{ClientID: "broadcast_client"})
			require.NoError(t, err)

			// Broadcast to an existing channel.
			err = n.hub.broadcastJoin("test_channel", &ClientInfo{ClientID: "broadcast_client"})
			require.NoError(t, err)
			select {
			case data := <-transport.sink:
				require.Contains(t, string(data), "broadcast_client")
			case <-time.After(2 * time.Second):
				t.Fatal("no data in sink")
			}
		})
	}
}

func TestHubBroadcastLeave(t *testing.T) {
	tcs := []struct {
		name         string
		protocolType ProtocolType
	}{
		{name: "JSON", protocolType: ProtocolTypeJSON},
		{name: "Protobuf", protocolType: ProtocolTypeProtobuf},
	}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			n := defaultTestNode()
			defer func() { _ = n.Shutdown(context.Background()) }()

			client := newTestSubscribedClient(t, n, "42", "test_channel")
			transport := client.transport.(*testTransport)
			transport.sink = make(chan []byte, 100)
			transport.protoType = tc.protocolType

			// Broadcast to a non-existent channel.
			err := n.hub.broadcastLeave("not_test_channel", &ClientInfo{ClientID: "broadcast_client"})
			require.NoError(t, err)

			// Broadcast to an existing channel.
			err = n.hub.broadcastLeave("test_channel", &ClientInfo{ClientID: "broadcast_client"})
			require.NoError(t, err)
			select {
			case data := <-transport.sink:
				require.Contains(t, string(data), "broadcast_client")
			case <-time.After(2 * time.Second):
				t.Fatal("no data in sink")
			}
		})
	}
}

func TestHubShutdown(t *testing.T) {
	h := newHub()
	err := h.shutdown(context.Background())
	require.NoError(t, err)
	h = newHub()
	c, err := newClient(context.Background(), defaultTestNode(), newTestTransport(func() {}))
	require.NoError(t, err)
	_ = h.add(c)

	err = h.shutdown(context.Background())
	require.NoError(t, err)

	ctxCanceled, cancel := context.WithCancel(context.Background())
	cancel()
	err = h.shutdown(ctxCanceled)
	require.EqualError(t, err, "context canceled")
}

func TestHubSubscriptions(t *testing.T) {
	h := newHub()
	c, err := newClient(context.Background(), defaultTestNode(), newTestTransport(func() {}))
	require.NoError(t, err)

	_, _ = h.addSub("test1", c)
	_, _ = h.addSub("test2", c)
	require.Equal(t, 2, h.NumChannels())
	require.Contains(t, h.Channels(), "test1")
	require.Contains(t, h.Channels(), "test2")
	require.NotZero(t, h.NumSubscribers("test1"))
	require.NotZero(t, h.NumSubscribers("test2"))

	// Non-existent sub.
	removed, err := h.removeSub("not_existed", c)
	require.NoError(t, err)
	require.True(t, removed)

	// Existing sub with invalid uid.
	validUID := c.uid
	c.uid = "invalid"
	removed, err = h.removeSub("test1", c)
	require.NoError(t, err)
	require.True(t, removed)
	c.uid = validUID

	// Existing sub.
	removed, err = h.removeSub("test1", c)
	require.NoError(t, err)
	require.True(t, removed)

	// Existing sub.
	removed, err = h.removeSub("test2", c)
	require.NoError(t, err)
	require.True(t, removed)

	require.Equal(t, 0, h.NumChannels())
	require.Zero(t, h.NumSubscribers("test1"))
	require.Zero(t, h.NumSubscribers("test2"))
}

func TestPreparedReply(t *testing.T) {
	reply := protocol.Reply{}
	preparedReply := prepared.NewReply(&reply, protocol.TypeJSON)
	data := preparedReply.Data()
	require.NotNil(t, data)
}

func TestUserConnections(t *testing.T) {
	h := newHub()
	c, err := newClient(context.Background(), defaultTestNode(), newTestTransport(func() {}))
	require.NoError(t, err)
	_ = h.add(c)

	connections := h.UserConnections(c.UserID())
	require.Equal(t, h.connShards[index(c.UserID(), numHubShards)].conns, connections)
}

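// TestHubSharding populates every connection and subscription shard and
// checks that per-shard and aggregated counters stay consistent.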
func TestHubSharding(t *testing.T) {
	numUsers := numHubShards * 10
	numChannels := numHubShards * 10

	channels := make([]string, 0, numChannels)
	for i := 0; i < numChannels; i++ {
		channels = append(channels, "ch"+strconv.Itoa(i))
	}

	n := defaultTestNode()
	defer func() { _ = n.Shutdown(context.Background()) }()

	for j := 0; j < 2; j++ { // Two connections from the same user.
		for i := 0; i < numUsers; i++ {
			c, err := newClient(context.Background(), n, newTestTransport(func() {}))
			require.NoError(t, err)
			c.user = strconv.Itoa(i)
			_ = n.hub.add(c)
			for _, ch := range channels {
				_, _ = n.hub.addSub(ch, c)
			}
		}
	}

	for i := range n.hub.connShards {
		require.NotZero(t, n.hub.connShards[i].NumClients())
		require.NotZero(t, n.hub.connShards[i].NumUsers())
	}
	for i := range n.hub.subShards {
		require.True(t, len(n.hub.subShards[i].subs) > 0)
	}

	require.Equal(t, numUsers, n.Hub().NumUsers())
	require.Equal(t, 2*numUsers, n.Hub().NumClients())
	require.Equal(t, numChannels, n.Hub().NumChannels())
}

// This benchmark allows estimating the benefit of Hub sharding. Since a
// broadcasting goroutine is involved, total allocations are not very
// informative here - it's better to look at operation time.
func BenchmarkHub_Contention(b *testing.B) {
	numClients := 100
	numChannels := 128

	n := defaultTestNodeBenchmark(b)

	var clients []*Client
	var channels []string

	for i := 0; i < numChannels; i++ {
		channels = append(channels, "ch"+strconv.Itoa(i))
	}

	for i := 0; i < numClients; i++ {
		c, err := newClient(context.Background(), n, newTestTransport(func() {}))
		require.NoError(b, err)
		_ = n.hub.add(c)
		clients = append(clients, c)
		for _, ch := range channels {
			_, _ = n.hub.addSub(ch, c)
		}
	}

	pub := &Publication{
		Data: []byte(`{"input": "test"}`),
	}
	streamPosition := StreamPosition{}

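	// Each iteration broadcasts to one channel from a goroutine while the
	// main goroutine concurrently adds a subscription on another channel,
	// so shard locks are exercised from both the broadcast and subscribe
	// paths.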
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		i := 0
		for pb.Next() {
			i++
			var wg sync.WaitGroup
			wg.Add(1)
			go func() {
				defer wg.Done()
				_ = n.hub.BroadcastPublication(channels[(i+numChannels/2)%numChannels], pub, streamPosition)
			}()
			_, _ = n.hub.addSub(channels[i%numChannels], clients[i%numClients])
			wg.Wait()
		}
	})
}

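// broadcastBenches parameterizes BenchmarkHub_MassiveBroadcast with the
// number of subscribers per channel.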
var broadcastBenches = []struct {
	NumSubscribers int
}{
	{1000},
	{10000},
	{100000},
}

// BenchmarkHub_MassiveBroadcast allows estimating the time to broadcast
// a single message to many subscribers within one channel.
func BenchmarkHub_MassiveBroadcast(b *testing.B) {
	pub := &Publication{Data: []byte(`{"input": "test"}`)}
	streamPosition := StreamPosition{}

	for _, tt := range broadcastBenches {
		numSubscribers := tt.NumSubscribers
		b.Run(fmt.Sprintf("%d", numSubscribers), func(b *testing.B) {
			b.ReportAllocs()
			n := defaultTestNodeBenchmark(b)

			numChannels := 64
			channels := make([]string, 0, numChannels)

			for i := 0; i < numChannels; i++ {
				channels = append(channels, "broadcast"+strconv.Itoa(i))
			}

			sink := make(chan []byte, 1024)

			for i := 0; i < numSubscribers; i++ {
				t := newTestTransport(func() {})
				t.setSink(sink)
				c, err := newClient(context.Background(), n, t)
				require.NoError(b, err)
				_ = n.hub.add(c)
				for _, ch := range channels {
					_, _ = n.hub.addSub(ch, c)
				}
			}

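			// Every broadcast fans out to numSubscribers transports sharing
			// one sink; the goroutine below drains exactly that many
			// messages before the iteration completes.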
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				var wg sync.WaitGroup
				wg.Add(1)
				go func() {
					defer wg.Done()
					j := 0
					for {
						<-sink
						j++
						if j == numSubscribers {
							break
						}
					}
				}()
				_ = n.hub.BroadcastPublication(channels[i%numChannels], pub, streamPosition)
				wg.Wait()
			}
		})
	}
}
663