1package pubsub
2
3import (
4	"bytes"
5	"context"
6	"fmt"
7	"math/rand"
8	"sync"
9	"testing"
10	"time"
11
12	"github.com/libp2p/go-libp2p-core/peer"
13)
14
15func getTopics(psubs []*PubSub, topicID string, opts ...TopicOpt) []*Topic {
16	topics := make([]*Topic, len(psubs))
17
18	for i, ps := range psubs {
19		t, err := ps.Join(topicID, opts...)
20		if err != nil {
21			panic(err)
22		}
23		topics[i] = t
24	}
25
26	return topics
27}
28
29func getTopicEvts(topics []*Topic, opts ...TopicEventHandlerOpt) []*TopicEventHandler {
30	handlers := make([]*TopicEventHandler, len(topics))
31
32	for i, t := range topics {
33		h, err := t.EventHandler(opts...)
34		if err != nil {
35			panic(err)
36		}
37		handlers[i] = h
38	}
39
40	return handlers
41}
42
43func TestTopicCloseWithOpenSubscription(t *testing.T) {
44	var sub *Subscription
45	var err error
46	testTopicCloseWithOpenResource(t,
47		func(topic *Topic) {
48			sub, err = topic.Subscribe()
49			if err != nil {
50				t.Fatal(err)
51			}
52		},
53		func() {
54			sub.Cancel()
55		},
56	)
57}
58
59func TestTopicCloseWithOpenEventHandler(t *testing.T) {
60	var evts *TopicEventHandler
61	var err error
62	testTopicCloseWithOpenResource(t,
63		func(topic *Topic) {
64			evts, err = topic.EventHandler()
65			if err != nil {
66				t.Fatal(err)
67			}
68		},
69		func() {
70			evts.Cancel()
71		},
72	)
73}
74
75func TestTopicCloseWithOpenRelay(t *testing.T) {
76	var relayCancel RelayCancelFunc
77	var err error
78	testTopicCloseWithOpenResource(t,
79		func(topic *Topic) {
80			relayCancel, err = topic.Relay()
81			if err != nil {
82				t.Fatal(err)
83			}
84		},
85		func() {
86			relayCancel()
87		},
88	)
89}
90
91func testTopicCloseWithOpenResource(t *testing.T, openResource func(topic *Topic), closeResource func()) {
92	ctx, cancel := context.WithCancel(context.Background())
93	defer cancel()
94
95	const numHosts = 1
96	topicID := "foobar"
97	hosts := getNetHosts(t, ctx, numHosts)
98	ps := getPubsub(ctx, hosts[0])
99
100	// Try create and cancel topic
101	topic, err := ps.Join(topicID)
102	if err != nil {
103		t.Fatal(err)
104	}
105
106	if err := topic.Close(); err != nil {
107		t.Fatal(err)
108	}
109
110	// Try create and cancel topic while there's an outstanding subscription/event handler
111	topic, err = ps.Join(topicID)
112	if err != nil {
113		t.Fatal(err)
114	}
115
116	openResource(topic)
117
118	if err := topic.Close(); err == nil {
119		t.Fatal("expected an error closing a topic with an open resource")
120	}
121
122	// Check if the topic closes properly after closing the resource
123	closeResource()
124	time.Sleep(time.Millisecond * 100)
125
126	if err := topic.Close(); err != nil {
127		t.Fatal(err)
128	}
129}
130
131func TestTopicReuse(t *testing.T) {
132	ctx, cancel := context.WithCancel(context.Background())
133	defer cancel()
134
135	const numHosts = 2
136	topicID := "foobar"
137	hosts := getNetHosts(t, ctx, numHosts)
138
139	sender := getPubsub(ctx, hosts[0], WithDiscovery(&dummyDiscovery{}))
140	receiver := getPubsub(ctx, hosts[1])
141
142	connectAll(t, hosts)
143
144	// Sender creates topic
145	sendTopic, err := sender.Join(topicID)
146	if err != nil {
147		t.Fatal(err)
148	}
149
150	// Receiver creates and subscribes to the topic
151	receiveTopic, err := receiver.Join(topicID)
152	if err != nil {
153		t.Fatal(err)
154	}
155
156	sub, err := receiveTopic.Subscribe()
157	if err != nil {
158		t.Fatal(err)
159	}
160
161	firstMsg := []byte("1")
162	if err := sendTopic.Publish(ctx, firstMsg, WithReadiness(MinTopicSize(1))); err != nil {
163		t.Fatal(err)
164	}
165
166	msg, err := sub.Next(ctx)
167	if err != nil {
168		t.Fatal(err)
169	}
170	if bytes.Compare(msg.GetData(), firstMsg) != 0 {
171		t.Fatal("received incorrect message")
172	}
173
174	if err := sendTopic.Close(); err != nil {
175		t.Fatal(err)
176	}
177
178	// Recreate the same topic
179	newSendTopic, err := sender.Join(topicID)
180	if err != nil {
181		t.Fatal(err)
182	}
183
184	// Try sending data with original topic
185	illegalSend := []byte("illegal")
186	if err := sendTopic.Publish(ctx, illegalSend); err != ErrTopicClosed {
187		t.Fatal(err)
188	}
189
190	timeoutCtx, timeoutCancel := context.WithTimeout(ctx, time.Second*2)
191	defer timeoutCancel()
192	msg, err = sub.Next(timeoutCtx)
193	if err != context.DeadlineExceeded {
194		if err != nil {
195			t.Fatal(err)
196		}
197		if bytes.Compare(msg.GetData(), illegalSend) != 0 {
198			t.Fatal("received incorrect message from illegal topic")
199		}
200		t.Fatal("received message sent by illegal topic")
201	}
202	timeoutCancel()
203
204	// Try cancelling the new topic by using the original topic
205	if err := sendTopic.Close(); err != nil {
206		t.Fatal(err)
207	}
208
209	secondMsg := []byte("2")
210	if err := newSendTopic.Publish(ctx, secondMsg); err != nil {
211		t.Fatal(err)
212	}
213
214	timeoutCtx, timeoutCancel = context.WithTimeout(ctx, time.Second*2)
215	defer timeoutCancel()
216	msg, err = sub.Next(ctx)
217	if err != nil {
218		t.Fatal(err)
219	}
220	if bytes.Compare(msg.GetData(), secondMsg) != 0 {
221		t.Fatal("received incorrect message")
222	}
223}
224
225func TestTopicEventHandlerCancel(t *testing.T) {
226	ctx, cancel := context.WithCancel(context.Background())
227	defer cancel()
228
229	const numHosts = 5
230	topicID := "foobar"
231	hosts := getNetHosts(t, ctx, numHosts)
232	ps := getPubsub(ctx, hosts[0])
233
234	// Try create and cancel topic
235	topic, err := ps.Join(topicID)
236	if err != nil {
237		t.Fatal(err)
238	}
239
240	evts, err := topic.EventHandler()
241	if err != nil {
242		t.Fatal(err)
243	}
244	evts.Cancel()
245	timeoutCtx, timeoutCancel := context.WithTimeout(ctx, time.Second*2)
246	defer timeoutCancel()
247	connectAll(t, hosts)
248	_, err = evts.NextPeerEvent(timeoutCtx)
249	if err != context.DeadlineExceeded {
250		if err != nil {
251			t.Fatal(err)
252		}
253		t.Fatal("received event after cancel")
254	}
255}
256
257func TestSubscriptionJoinNotification(t *testing.T) {
258	ctx, cancel := context.WithCancel(context.Background())
259	defer cancel()
260
261	const numLateSubscribers = 10
262	const numHosts = 20
263	hosts := getNetHosts(t, ctx, numHosts)
264	topics := getTopics(getPubsubs(ctx, hosts), "foobar")
265	evts := getTopicEvts(topics)
266
267	subs := make([]*Subscription, numHosts)
268	topicPeersFound := make([]map[peer.ID]struct{}, numHosts)
269
270	// Have some peers subscribe earlier than other peers.
271	// This exercises whether we get subscription notifications from
272	// existing peers.
273	for i, topic := range topics[numLateSubscribers:] {
274		subch, err := topic.Subscribe()
275		if err != nil {
276			t.Fatal(err)
277		}
278
279		subs[i] = subch
280	}
281
282	connectAll(t, hosts)
283
284	time.Sleep(time.Millisecond * 100)
285
286	// Have the rest subscribe
287	for i, topic := range topics[:numLateSubscribers] {
288		subch, err := topic.Subscribe()
289		if err != nil {
290			t.Fatal(err)
291		}
292
293		subs[i+numLateSubscribers] = subch
294	}
295
296	wg := sync.WaitGroup{}
297	for i := 0; i < numHosts; i++ {
298		peersFound := make(map[peer.ID]struct{})
299		topicPeersFound[i] = peersFound
300		evt := evts[i]
301		wg.Add(1)
302		go func(peersFound map[peer.ID]struct{}) {
303			defer wg.Done()
304			for len(peersFound) < numHosts-1 {
305				event, err := evt.NextPeerEvent(ctx)
306				if err != nil {
307					panic(err)
308				}
309				if event.Type == PeerJoin {
310					peersFound[event.Peer] = struct{}{}
311				}
312			}
313		}(peersFound)
314	}
315
316	wg.Wait()
317	for _, peersFound := range topicPeersFound {
318		if len(peersFound) != numHosts-1 {
319			t.Fatal("incorrect number of peers found")
320		}
321	}
322}
323
324func TestSubscriptionLeaveNotification(t *testing.T) {
325	ctx, cancel := context.WithCancel(context.Background())
326	defer cancel()
327
328	const numHosts = 20
329	hosts := getNetHosts(t, ctx, numHosts)
330	psubs := getPubsubs(ctx, hosts)
331	topics := getTopics(psubs, "foobar")
332	evts := getTopicEvts(topics)
333
334	subs := make([]*Subscription, numHosts)
335	topicPeersFound := make([]map[peer.ID]struct{}, numHosts)
336
337	// Subscribe all peers and wait until they've all been found
338	for i, topic := range topics {
339		subch, err := topic.Subscribe()
340		if err != nil {
341			t.Fatal(err)
342		}
343
344		subs[i] = subch
345	}
346
347	connectAll(t, hosts)
348
349	time.Sleep(time.Millisecond * 100)
350
351	wg := sync.WaitGroup{}
352	for i := 0; i < numHosts; i++ {
353		peersFound := make(map[peer.ID]struct{})
354		topicPeersFound[i] = peersFound
355		evt := evts[i]
356		wg.Add(1)
357		go func(peersFound map[peer.ID]struct{}) {
358			defer wg.Done()
359			for len(peersFound) < numHosts-1 {
360				event, err := evt.NextPeerEvent(ctx)
361				if err != nil {
362					panic(err)
363				}
364				if event.Type == PeerJoin {
365					peersFound[event.Peer] = struct{}{}
366				}
367			}
368		}(peersFound)
369	}
370
371	wg.Wait()
372	for _, peersFound := range topicPeersFound {
373		if len(peersFound) != numHosts-1 {
374			t.Fatal("incorrect number of peers found")
375		}
376	}
377
378	// Test removing peers and verifying that they cause events
379	subs[1].Cancel()
380	hosts[2].Close()
381	psubs[0].BlacklistPeer(hosts[3].ID())
382
383	leavingPeers := make(map[peer.ID]struct{})
384	for len(leavingPeers) < 3 {
385		event, err := evts[0].NextPeerEvent(ctx)
386		if err != nil {
387			t.Fatal(err)
388		}
389		if event.Type == PeerLeave {
390			leavingPeers[event.Peer] = struct{}{}
391		}
392	}
393
394	if _, ok := leavingPeers[hosts[1].ID()]; !ok {
395		t.Fatal(fmt.Errorf("canceling subscription did not cause a leave event"))
396	}
397	if _, ok := leavingPeers[hosts[2].ID()]; !ok {
398		t.Fatal(fmt.Errorf("closing host did not cause a leave event"))
399	}
400	if _, ok := leavingPeers[hosts[3].ID()]; !ok {
401		t.Fatal(fmt.Errorf("blacklisting peer did not cause a leave event"))
402	}
403}
404
405func TestSubscriptionManyNotifications(t *testing.T) {
406	t.Skip("flaky test disabled")
407
408	ctx, cancel := context.WithCancel(context.Background())
409	defer cancel()
410
411	const topic = "foobar"
412
413	const numHosts = 33
414	hosts := getNetHosts(t, ctx, numHosts)
415	topics := getTopics(getPubsubs(ctx, hosts), topic)
416	evts := getTopicEvts(topics)
417
418	subs := make([]*Subscription, numHosts)
419	topicPeersFound := make([]map[peer.ID]struct{}, numHosts)
420
421	// Subscribe all peers except one and wait until they've all been found
422	for i := 1; i < numHosts; i++ {
423		subch, err := topics[i].Subscribe()
424		if err != nil {
425			t.Fatal(err)
426		}
427
428		subs[i] = subch
429	}
430
431	connectAll(t, hosts)
432
433	time.Sleep(time.Millisecond * 100)
434
435	wg := sync.WaitGroup{}
436	for i := 1; i < numHosts; i++ {
437		peersFound := make(map[peer.ID]struct{})
438		topicPeersFound[i] = peersFound
439		evt := evts[i]
440		wg.Add(1)
441		go func(peersFound map[peer.ID]struct{}) {
442			defer wg.Done()
443			for len(peersFound) < numHosts-2 {
444				event, err := evt.NextPeerEvent(ctx)
445				if err != nil {
446					panic(err)
447				}
448				if event.Type == PeerJoin {
449					peersFound[event.Peer] = struct{}{}
450				}
451			}
452		}(peersFound)
453	}
454
455	wg.Wait()
456	for _, peersFound := range topicPeersFound[1:] {
457		if len(peersFound) != numHosts-2 {
458			t.Fatalf("found %d peers, expected %d", len(peersFound), numHosts-2)
459		}
460	}
461
462	// Wait for remaining peer to find other peers
463	remPeerTopic, remPeerEvts := topics[0], evts[0]
464	for len(remPeerTopic.ListPeers()) < numHosts-1 {
465		time.Sleep(time.Millisecond * 100)
466	}
467
468	// Subscribe the remaining peer and check that all the events came through
469	sub, err := remPeerTopic.Subscribe()
470	if err != nil {
471		t.Fatal(err)
472	}
473
474	subs[0] = sub
475
476	peerState := readAllQueuedEvents(ctx, t, remPeerEvts)
477
478	if len(peerState) != numHosts-1 {
479		t.Fatal("incorrect number of peers found")
480	}
481
482	for _, e := range peerState {
483		if e != PeerJoin {
484			t.Fatal("non Join event occurred")
485		}
486	}
487
488	// Unsubscribe all peers except one and check that all the events came through
489	for i := 1; i < numHosts; i++ {
490		subs[i].Cancel()
491	}
492
493	// Wait for remaining peer to disconnect from the other peers
494	for len(topics[0].ListPeers()) != 0 {
495		time.Sleep(time.Millisecond * 100)
496	}
497
498	peerState = readAllQueuedEvents(ctx, t, remPeerEvts)
499
500	if len(peerState) != numHosts-1 {
501		t.Fatal("incorrect number of peers found")
502	}
503
504	for _, e := range peerState {
505		if e != PeerLeave {
506			t.Fatal("non Leave event occurred")
507		}
508	}
509}
510
511func TestSubscriptionNotificationSubUnSub(t *testing.T) {
512	// Resubscribe and Unsubscribe a peers and check the state for consistency
513	ctx, cancel := context.WithCancel(context.Background())
514	defer cancel()
515
516	const topic = "foobar"
517
518	const numHosts = 35
519	hosts := getNetHosts(t, ctx, numHosts)
520	topics := getTopics(getPubsubs(ctx, hosts), topic)
521
522	for i := 1; i < numHosts; i++ {
523		connect(t, hosts[0], hosts[i])
524	}
525	time.Sleep(time.Millisecond * 100)
526
527	notifSubThenUnSub(ctx, t, topics)
528}
529
530func TestTopicRelay(t *testing.T) {
531	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
532	defer cancel()
533
534	const topic = "foobar"
535	const numHosts = 5
536
537	hosts := getNetHosts(t, ctx, numHosts)
538	topics := getTopics(getPubsubs(ctx, hosts), topic)
539
540	// [0.Rel] - [1.Rel] - [2.Sub]
541	//             |
542	//           [3.Rel] - [4.Sub]
543
544	connect(t, hosts[0], hosts[1])
545	connect(t, hosts[1], hosts[2])
546	connect(t, hosts[1], hosts[3])
547	connect(t, hosts[3], hosts[4])
548
549	time.Sleep(time.Millisecond * 100)
550
551	var subs []*Subscription
552
553	for i, topic := range topics {
554		if i == 2 || i == 4 {
555			sub, err := topic.Subscribe()
556			if err != nil {
557				t.Fatal(err)
558			}
559
560			subs = append(subs, sub)
561		} else {
562			_, err := topic.Relay()
563			if err != nil {
564				t.Fatal(err)
565			}
566		}
567	}
568
569	time.Sleep(time.Millisecond * 100)
570
571	for i := 0; i < 100; i++ {
572		msg := []byte("message")
573
574		owner := rand.Intn(len(topics))
575
576		err := topics[owner].Publish(ctx, msg)
577		if err != nil {
578			t.Fatal(err)
579		}
580
581		for _, sub := range subs {
582			received, err := sub.Next(ctx)
583			if err != nil {
584				t.Fatal(err)
585			}
586
587			if !bytes.Equal(msg, received.Data) {
588				t.Fatal("received message is other than expected")
589			}
590		}
591	}
592}
593
594func TestTopicRelayReuse(t *testing.T) {
595	ctx, cancel := context.WithCancel(context.Background())
596	defer cancel()
597
598	const topic = "foobar"
599	const numHosts = 1
600
601	hosts := getNetHosts(t, ctx, numHosts)
602	pubsubs := getPubsubs(ctx, hosts)
603	topics := getTopics(pubsubs, topic)
604
605	relay1Cancel, err := topics[0].Relay()
606	if err != nil {
607		t.Fatal(err)
608	}
609
610	relay2Cancel, err := topics[0].Relay()
611	if err != nil {
612		t.Fatal(err)
613	}
614
615	relay3Cancel, err := topics[0].Relay()
616	if err != nil {
617		t.Fatal(err)
618	}
619
620	time.Sleep(time.Millisecond * 100)
621
622	res := make(chan bool, 1)
623	pubsubs[0].eval <- func() {
624		res <- pubsubs[0].myRelays[topic] == 3
625	}
626
627	isCorrectNumber := <-res
628	if !isCorrectNumber {
629		t.Fatal("incorrect number of relays")
630	}
631
632	// only the first invocation should take effect
633	relay1Cancel()
634	relay1Cancel()
635	relay1Cancel()
636
637	pubsubs[0].eval <- func() {
638		res <- pubsubs[0].myRelays[topic] == 2
639	}
640
641	isCorrectNumber = <-res
642	if !isCorrectNumber {
643		t.Fatal("incorrect number of relays")
644	}
645
646	relay2Cancel()
647	relay3Cancel()
648
649	time.Sleep(time.Millisecond * 100)
650
651	pubsubs[0].eval <- func() {
652		res <- pubsubs[0].myRelays[topic] == 0
653	}
654
655	isCorrectNumber = <-res
656	if !isCorrectNumber {
657		t.Fatal("incorrect number of relays")
658	}
659}
660
661func TestTopicRelayOnClosedTopic(t *testing.T) {
662	ctx, cancel := context.WithCancel(context.Background())
663	defer cancel()
664
665	const topic = "foobar"
666	const numHosts = 1
667
668	hosts := getNetHosts(t, ctx, numHosts)
669	topics := getTopics(getPubsubs(ctx, hosts), topic)
670
671	err := topics[0].Close()
672	if err != nil {
673		t.Fatal(err)
674	}
675
676	_, err = topics[0].Relay()
677	if err == nil {
678		t.Fatalf("error should be returned")
679	}
680}
681
682func notifSubThenUnSub(ctx context.Context, t *testing.T, topics []*Topic) {
683	primaryTopic := topics[0]
684	msgs := make([]*Subscription, len(topics))
685	checkSize := len(topics) - 1
686
687	// Subscribe all peers to the topic
688	var err error
689	for i, topic := range topics {
690		msgs[i], err = topic.Subscribe()
691		if err != nil {
692			t.Fatal(err)
693		}
694	}
695
696	// Wait for the primary peer to be connected to the other peers
697	for len(primaryTopic.ListPeers()) < checkSize {
698		time.Sleep(time.Millisecond * 100)
699	}
700
701	// Unsubscribe all peers except the primary
702	for i := 1; i < checkSize+1; i++ {
703		msgs[i].Cancel()
704	}
705
706	// Wait for the unsubscribe messages to reach the primary peer
707	for len(primaryTopic.ListPeers()) < 0 {
708		time.Sleep(time.Millisecond * 100)
709	}
710
711	// read all available events and verify that there are no events to process
712	// this is because every peer that joined also left
713	primaryEvts, err := primaryTopic.EventHandler()
714	if err != nil {
715		t.Fatal(err)
716	}
717	peerState := readAllQueuedEvents(ctx, t, primaryEvts)
718
719	if len(peerState) != 0 {
720		for p, s := range peerState {
721			fmt.Println(p, s)
722		}
723		t.Fatalf("Received incorrect events. %d extra events", len(peerState))
724	}
725}
726
727func readAllQueuedEvents(ctx context.Context, t *testing.T, evt *TopicEventHandler) map[peer.ID]EventType {
728	peerState := make(map[peer.ID]EventType)
729	for {
730		ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
731		event, err := evt.NextPeerEvent(ctx)
732		cancel()
733
734		if err == context.DeadlineExceeded {
735			break
736		} else if err != nil {
737			t.Fatal(err)
738		}
739
740		e, ok := peerState[event.Peer]
741		if !ok {
742			peerState[event.Peer] = event.Type
743		} else if e != event.Type {
744			delete(peerState, event.Peer)
745		}
746	}
747	return peerState
748}
749