1package pubsub
2
3import (
4	"bytes"
5	"context"
6	"crypto/sha256"
7	"encoding/base64"
8	"fmt"
9	"io"
10	"math/rand"
11	"sort"
12	"sync"
13	"testing"
14	"time"
15
16	pb "github.com/libp2p/go-libp2p-pubsub/pb"
17
18	"github.com/libp2p/go-libp2p-core/host"
19	"github.com/libp2p/go-libp2p-core/network"
20	"github.com/libp2p/go-libp2p-core/peer"
21	"github.com/libp2p/go-libp2p-core/protocol"
22
23	bhost "github.com/libp2p/go-libp2p-blankhost"
24	swarmt "github.com/libp2p/go-libp2p-swarm/testing"
25
26	"github.com/libp2p/go-msgio/protoio"
27)
28
29func checkMessageRouting(t *testing.T, topic string, pubs []*PubSub, subs []*Subscription) {
30	data := make([]byte, 16)
31	rand.Read(data)
32
33	for _, p := range pubs {
34		err := p.Publish(topic, data)
35		if err != nil {
36			t.Fatal(err)
37		}
38
39		for _, s := range subs {
40			assertReceive(t, s, data)
41		}
42	}
43}
44
45func getNetHosts(t *testing.T, ctx context.Context, n int) []host.Host {
46	var out []host.Host
47
48	for i := 0; i < n; i++ {
49		netw := swarmt.GenSwarm(t, ctx)
50		h := bhost.NewBlankHost(netw)
51		out = append(out, h)
52	}
53
54	return out
55}
56
57func connect(t *testing.T, a, b host.Host) {
58	pinfo := a.Peerstore().PeerInfo(a.ID())
59	err := b.Connect(context.Background(), pinfo)
60	if err != nil {
61		t.Fatal(err)
62	}
63}
64
65func sparseConnect(t *testing.T, hosts []host.Host) {
66	connectSome(t, hosts, 3)
67}
68
69func denseConnect(t *testing.T, hosts []host.Host) {
70	connectSome(t, hosts, 10)
71}
72
73func connectSome(t *testing.T, hosts []host.Host, d int) {
74	for i, a := range hosts {
75		for j := 0; j < d; j++ {
76			n := rand.Intn(len(hosts))
77			if n == i {
78				j--
79				continue
80			}
81
82			b := hosts[n]
83
84			connect(t, a, b)
85		}
86	}
87}
88
89func connectAll(t *testing.T, hosts []host.Host) {
90	for i, a := range hosts {
91		for j, b := range hosts {
92			if i == j {
93				continue
94			}
95
96			connect(t, a, b)
97		}
98	}
99}
100
101func getPubsub(ctx context.Context, h host.Host, opts ...Option) *PubSub {
102	ps, err := NewFloodSub(ctx, h, opts...)
103	if err != nil {
104		panic(err)
105	}
106	return ps
107}
108
109func getPubsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSub {
110	var psubs []*PubSub
111	for _, h := range hs {
112		psubs = append(psubs, getPubsub(ctx, h, opts...))
113	}
114	return psubs
115}
116
117func assertReceive(t *testing.T, ch *Subscription, exp []byte) {
118	select {
119	case msg := <-ch.ch:
120		if !bytes.Equal(msg.GetData(), exp) {
121			t.Fatalf("got wrong message, expected %s but got %s", string(exp), string(msg.GetData()))
122		}
123	case <-time.After(time.Second * 5):
124		t.Logf("%#v\n", ch)
125		t.Fatal("timed out waiting for message of: ", string(exp))
126	}
127}
128
129func TestBasicFloodsub(t *testing.T) {
130	ctx, cancel := context.WithCancel(context.Background())
131	defer cancel()
132	hosts := getNetHosts(t, ctx, 20)
133
134	psubs := getPubsubs(ctx, hosts)
135
136	var msgs []*Subscription
137	for _, ps := range psubs {
138		subch, err := ps.Subscribe("foobar")
139		if err != nil {
140			t.Fatal(err)
141		}
142
143		msgs = append(msgs, subch)
144	}
145
146	//connectAll(t, hosts)
147	sparseConnect(t, hosts)
148
149	time.Sleep(time.Millisecond * 100)
150
151	for i := 0; i < 100; i++ {
152		msg := []byte(fmt.Sprintf("%d the flooooooood %d", i, i))
153
154		owner := rand.Intn(len(psubs))
155
156		psubs[owner].Publish("foobar", msg)
157
158		for _, sub := range msgs {
159			got, err := sub.Next(ctx)
160			if err != nil {
161				t.Fatal(sub.err)
162			}
163			if !bytes.Equal(msg, got.Data) {
164				t.Fatal("got wrong message!")
165			}
166		}
167	}
168
169}
170
171func TestMultihops(t *testing.T) {
172	ctx, cancel := context.WithCancel(context.Background())
173	defer cancel()
174
175	hosts := getNetHosts(t, ctx, 6)
176
177	psubs := getPubsubs(ctx, hosts)
178
179	connect(t, hosts[0], hosts[1])
180	connect(t, hosts[1], hosts[2])
181	connect(t, hosts[2], hosts[3])
182	connect(t, hosts[3], hosts[4])
183	connect(t, hosts[4], hosts[5])
184
185	var subs []*Subscription
186	for i := 1; i < 6; i++ {
187		ch, err := psubs[i].Subscribe("foobar")
188		if err != nil {
189			t.Fatal(err)
190		}
191		subs = append(subs, ch)
192	}
193
194	time.Sleep(time.Millisecond * 100)
195
196	msg := []byte("i like cats")
197	err := psubs[0].Publish("foobar", msg)
198	if err != nil {
199		t.Fatal(err)
200	}
201
202	// last node in the chain should get the message
203	select {
204	case out := <-subs[4].ch:
205		if !bytes.Equal(out.GetData(), msg) {
206			t.Fatal("got wrong data")
207		}
208	case <-time.After(time.Second * 5):
209		t.Fatal("timed out waiting for message")
210	}
211}
212
// TestReconnects checks three things on topic "cats":
//   - cancelling a subscription closes its channel and stops delivery to it;
//   - the node then reports no local subscriptions for the topic;
//   - re-subscribing on the same node receives new messages again.
func TestReconnects(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 3)

	psubs := getPubsubs(ctx, hosts)

	// Star topology: host 0 (the publisher) is connected to hosts 1 and 2.
	connect(t, hosts[0], hosts[1])
	connect(t, hosts[0], hosts[2])

	A, err := psubs[1].Subscribe("cats")
	if err != nil {
		t.Fatal(err)
	}

	B, err := psubs[2].Subscribe("cats")
	if err != nil {
		t.Fatal(err)
	}

	// Allow time for the subscriptions to propagate before publishing.
	time.Sleep(time.Millisecond * 100)

	msg := []byte("apples and oranges")
	err = psubs[0].Publish("cats", msg)
	if err != nil {
		t.Fatal(err)
	}

	assertReceive(t, A, msg)
	assertReceive(t, B, msg)

	B.Cancel()

	// Allow time for the cancellation to be processed before publishing again.
	time.Sleep(time.Millisecond * 50)

	msg2 := []byte("potato")
	err = psubs[0].Publish("cats", msg2)
	if err != nil {
		t.Fatal(err)
	}

	assertReceive(t, A, msg2)
	// B's channel must be closed (ok == false) rather than deliver msg2.
	select {
	case _, ok := <-B.ch:
		if ok {
			t.Fatal("shouldnt have gotten data on this channel")
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for B chan to be closed")
	}

	// NOTE(review): mySubs is read directly from the test goroutine. If the
	// PubSub event loop owns this map, this is a data race under -race.
	// TODO confirm, and if so read it through the event loop instead.
	nSubs := len(psubs[2].mySubs["cats"])
	if nSubs > 0 {
		t.Fatal(`B should have 0 subscribers for channel "cats", has`, nSubs)
	}

	// Re-subscribing on host 2 must start receiving messages again.
	ch2, err := psubs[2].Subscribe("cats")
	if err != nil {
		t.Fatal(err)
	}

	time.Sleep(time.Millisecond * 100)

	nextmsg := []byte("ifps is kul")
	err = psubs[0].Publish("cats", nextmsg)
	if err != nil {
		t.Fatal(err)
	}

	assertReceive(t, ch2, nextmsg)
}
285
286// make sure messages arent routed between nodes who arent subscribed
287func TestNoConnection(t *testing.T) {
288	ctx, cancel := context.WithCancel(context.Background())
289	defer cancel()
290
291	hosts := getNetHosts(t, ctx, 10)
292
293	psubs := getPubsubs(ctx, hosts)
294
295	ch, err := psubs[5].Subscribe("foobar")
296	if err != nil {
297		t.Fatal(err)
298	}
299
300	err = psubs[0].Publish("foobar", []byte("TESTING"))
301	if err != nil {
302		t.Fatal(err)
303	}
304
305	select {
306	case <-ch.ch:
307		t.Fatal("shouldnt have gotten a message")
308	case <-time.After(time.Millisecond * 200):
309	}
310}
311
312func TestSelfReceive(t *testing.T) {
313	ctx, cancel := context.WithCancel(context.Background())
314	defer cancel()
315
316	host := getNetHosts(t, ctx, 1)[0]
317
318	psub, err := NewFloodSub(ctx, host)
319	if err != nil {
320		t.Fatal(err)
321	}
322
323	msg := []byte("hello world")
324
325	err = psub.Publish("foobar", msg)
326	if err != nil {
327		t.Fatal(err)
328	}
329
330	time.Sleep(time.Millisecond * 10)
331
332	ch, err := psub.Subscribe("foobar")
333	if err != nil {
334		t.Fatal(err)
335	}
336
337	msg2 := []byte("goodbye world")
338	err = psub.Publish("foobar", msg2)
339	if err != nil {
340		t.Fatal(err)
341	}
342
343	assertReceive(t, ch, msg2)
344}
345
346func TestOneToOne(t *testing.T) {
347	ctx, cancel := context.WithCancel(context.Background())
348	defer cancel()
349
350	hosts := getNetHosts(t, ctx, 2)
351	psubs := getPubsubs(ctx, hosts)
352
353	connect(t, hosts[0], hosts[1])
354
355	sub, err := psubs[1].Subscribe("foobar")
356	if err != nil {
357		t.Fatal(err)
358	}
359
360	time.Sleep(time.Millisecond * 50)
361
362	checkMessageRouting(t, "foobar", psubs, []*Subscription{sub})
363}
364
365func assertPeerLists(t *testing.T, hosts []host.Host, ps *PubSub, has ...int) {
366	peers := ps.ListPeers("")
367	set := make(map[peer.ID]struct{})
368	for _, p := range peers {
369		set[p] = struct{}{}
370	}
371
372	for _, h := range has {
373		if _, ok := set[hosts[h].ID()]; !ok {
374			t.Fatal("expected to have connection to peer: ", h)
375		}
376	}
377}
378
379func TestTreeTopology(t *testing.T) {
380	ctx, cancel := context.WithCancel(context.Background())
381	defer cancel()
382
383	hosts := getNetHosts(t, ctx, 10)
384	psubs := getPubsubs(ctx, hosts)
385
386	connect(t, hosts[0], hosts[1])
387	connect(t, hosts[1], hosts[2])
388	connect(t, hosts[1], hosts[4])
389	connect(t, hosts[2], hosts[3])
390	connect(t, hosts[0], hosts[5])
391	connect(t, hosts[5], hosts[6])
392	connect(t, hosts[5], hosts[8])
393	connect(t, hosts[6], hosts[7])
394	connect(t, hosts[8], hosts[9])
395
396	/*
397		[0] -> [1] -> [2] -> [3]
398		 |      L->[4]
399		 v
400		[5] -> [6] -> [7]
401		 |
402		 v
403		[8] -> [9]
404	*/
405
406	var chs []*Subscription
407	for _, ps := range psubs {
408		ch, err := ps.Subscribe("fizzbuzz")
409		if err != nil {
410			t.Fatal(err)
411		}
412
413		chs = append(chs, ch)
414	}
415
416	time.Sleep(time.Millisecond * 50)
417
418	assertPeerLists(t, hosts, psubs[0], 1, 5)
419	assertPeerLists(t, hosts, psubs[1], 0, 2, 4)
420	assertPeerLists(t, hosts, psubs[2], 1, 3)
421
422	checkMessageRouting(t, "fizzbuzz", []*PubSub{psubs[9], psubs[3]}, chs)
423}
424
425func assertHasTopics(t *testing.T, ps *PubSub, exptopics ...string) {
426	topics := ps.GetTopics()
427	sort.Strings(topics)
428	sort.Strings(exptopics)
429
430	if len(topics) != len(exptopics) {
431		t.Fatalf("expected to have %v, but got %v", exptopics, topics)
432	}
433
434	for i, v := range exptopics {
435		if topics[i] != v {
436			t.Fatalf("expected %s but have %s", v, topics[i])
437		}
438	}
439}
440
441func TestFloodSubPluggableProtocol(t *testing.T) {
442	t.Run("multi-procol router acts like a hub", func(t *testing.T) {
443		ctx, cancel := context.WithCancel(context.Background())
444		defer cancel()
445
446		hosts := getNetHosts(t, ctx, 3)
447
448		psubA := mustCreatePubSub(ctx, t, hosts[0], "/esh/floodsub", "/lsr/floodsub")
449		psubB := mustCreatePubSub(ctx, t, hosts[1], "/esh/floodsub")
450		psubC := mustCreatePubSub(ctx, t, hosts[2], "/lsr/floodsub")
451
452		subA := mustSubscribe(t, psubA, "foobar")
453		defer subA.Cancel()
454
455		subB := mustSubscribe(t, psubB, "foobar")
456		defer subB.Cancel()
457
458		subC := mustSubscribe(t, psubC, "foobar")
459		defer subC.Cancel()
460
461		// B --> A, C --> A
462		connect(t, hosts[1], hosts[0])
463		connect(t, hosts[2], hosts[0])
464
465		time.Sleep(time.Millisecond * 100)
466
467		psubC.Publish("foobar", []byte("bar"))
468
469		assertReceive(t, subA, []byte("bar"))
470		assertReceive(t, subB, []byte("bar"))
471		assertReceive(t, subC, []byte("bar"))
472	})
473
474	t.Run("won't talk to routers with no protocol overlap", func(t *testing.T) {
475		ctx, cancel := context.WithCancel(context.Background())
476		defer cancel()
477
478		hosts := getNetHosts(t, ctx, 2)
479
480		psubA := mustCreatePubSub(ctx, t, hosts[0], "/esh/floodsub")
481		psubB := mustCreatePubSub(ctx, t, hosts[1], "/lsr/floodsub")
482
483		subA := mustSubscribe(t, psubA, "foobar")
484		defer subA.Cancel()
485
486		subB := mustSubscribe(t, psubB, "foobar")
487		defer subB.Cancel()
488
489		connect(t, hosts[1], hosts[0])
490
491		time.Sleep(time.Millisecond * 100)
492
493		psubA.Publish("foobar", []byte("bar"))
494
495		assertReceive(t, subA, []byte("bar"))
496
497		pass := false
498		select {
499		case <-subB.ch:
500			t.Fatal("different protocols: should not have received message")
501		case <-time.After(time.Second * 1):
502			pass = true
503		}
504
505		if !pass {
506			t.Fatal("should have timed out waiting for message")
507		}
508	})
509}
510
511func mustCreatePubSub(ctx context.Context, t *testing.T, h host.Host, ps ...protocol.ID) *PubSub {
512	psub, err := NewFloodsubWithProtocols(ctx, h, ps)
513	if err != nil {
514		t.Fatal(err)
515	}
516
517	return psub
518}
519
520func mustSubscribe(t *testing.T, ps *PubSub, topic string) *Subscription {
521	sub, err := ps.Subscribe(topic)
522	if err != nil {
523		t.Fatal(err)
524	}
525
526	return sub
527}
528
529func TestSubReporting(t *testing.T) {
530	ctx, cancel := context.WithCancel(context.Background())
531	defer cancel()
532
533	host := getNetHosts(t, ctx, 1)[0]
534	psub, err := NewFloodSub(ctx, host)
535	if err != nil {
536		t.Fatal(err)
537	}
538
539	fooSub, err := psub.Subscribe("foo")
540	if err != nil {
541		t.Fatal(err)
542	}
543
544	barSub, err := psub.Subscribe("bar")
545	if err != nil {
546		t.Fatal(err)
547	}
548
549	assertHasTopics(t, psub, "foo", "bar")
550
551	_, err = psub.Subscribe("baz")
552	if err != nil {
553		t.Fatal(err)
554	}
555
556	assertHasTopics(t, psub, "foo", "bar", "baz")
557
558	barSub.Cancel()
559	assertHasTopics(t, psub, "foo", "baz")
560	fooSub.Cancel()
561	assertHasTopics(t, psub, "baz")
562
563	_, err = psub.Subscribe("fish")
564	if err != nil {
565		t.Fatal(err)
566	}
567
568	assertHasTopics(t, psub, "baz", "fish")
569}
570
571func TestPeerTopicReporting(t *testing.T) {
572	ctx, cancel := context.WithCancel(context.Background())
573	defer cancel()
574
575	hosts := getNetHosts(t, ctx, 4)
576	psubs := getPubsubs(ctx, hosts)
577
578	connect(t, hosts[0], hosts[1])
579	connect(t, hosts[0], hosts[2])
580	connect(t, hosts[0], hosts[3])
581
582	_, err := psubs[1].Subscribe("foo")
583	if err != nil {
584		t.Fatal(err)
585	}
586	_, err = psubs[1].Subscribe("bar")
587	if err != nil {
588		t.Fatal(err)
589	}
590	_, err = psubs[1].Subscribe("baz")
591	if err != nil {
592		t.Fatal(err)
593	}
594
595	_, err = psubs[2].Subscribe("foo")
596	if err != nil {
597		t.Fatal(err)
598	}
599	_, err = psubs[2].Subscribe("ipfs")
600	if err != nil {
601		t.Fatal(err)
602	}
603
604	_, err = psubs[3].Subscribe("baz")
605	if err != nil {
606		t.Fatal(err)
607	}
608	_, err = psubs[3].Subscribe("ipfs")
609	if err != nil {
610		t.Fatal(err)
611	}
612
613	time.Sleep(time.Millisecond * 200)
614
615	peers := psubs[0].ListPeers("ipfs")
616	assertPeerList(t, peers, hosts[2].ID(), hosts[3].ID())
617
618	peers = psubs[0].ListPeers("foo")
619	assertPeerList(t, peers, hosts[1].ID(), hosts[2].ID())
620
621	peers = psubs[0].ListPeers("baz")
622	assertPeerList(t, peers, hosts[1].ID(), hosts[3].ID())
623
624	peers = psubs[0].ListPeers("bar")
625	assertPeerList(t, peers, hosts[1].ID())
626}
627
628func TestSubscribeMultipleTimes(t *testing.T) {
629	ctx, cancel := context.WithCancel(context.Background())
630	defer cancel()
631
632	hosts := getNetHosts(t, ctx, 2)
633	psubs := getPubsubs(ctx, hosts)
634
635	connect(t, hosts[0], hosts[1])
636
637	sub1, err := psubs[0].Subscribe("foo")
638	if err != nil {
639		t.Fatal(err)
640	}
641	sub2, err := psubs[0].Subscribe("foo")
642	if err != nil {
643		t.Fatal(err)
644	}
645
646	// make sure subscribing is finished by the time we publish
647	time.Sleep(10 * time.Millisecond)
648
649	psubs[1].Publish("foo", []byte("bar"))
650
651	msg, err := sub1.Next(ctx)
652	if err != nil {
653		t.Fatalf("unexpected error: %v.", err)
654	}
655
656	data := string(msg.GetData())
657
658	if data != "bar" {
659		t.Fatalf("data is %s, expected %s.", data, "bar")
660	}
661
662	msg, err = sub2.Next(ctx)
663	if err != nil {
664		t.Fatalf("unexpected error: %v.", err)
665	}
666	data = string(msg.GetData())
667
668	if data != "bar" {
669		t.Fatalf("data is %s, expected %s.", data, "bar")
670	}
671}
672
673func TestPeerDisconnect(t *testing.T) {
674	ctx, cancel := context.WithCancel(context.Background())
675	defer cancel()
676
677	hosts := getNetHosts(t, ctx, 2)
678	psubs := getPubsubs(ctx, hosts)
679
680	connect(t, hosts[0], hosts[1])
681
682	_, err := psubs[0].Subscribe("foo")
683	if err != nil {
684		t.Fatal(err)
685	}
686
687	_, err = psubs[1].Subscribe("foo")
688	if err != nil {
689		t.Fatal(err)
690	}
691
692	time.Sleep(time.Millisecond * 300)
693
694	peers := psubs[0].ListPeers("foo")
695	assertPeerList(t, peers, hosts[1].ID())
696	for _, c := range hosts[1].Network().ConnsToPeer(hosts[0].ID()) {
697		c.Close()
698	}
699
700	time.Sleep(time.Millisecond * 300)
701
702	peers = psubs[0].ListPeers("foo")
703	assertPeerList(t, peers)
704}
705
706func assertPeerList(t *testing.T, peers []peer.ID, expected ...peer.ID) {
707	sort.Sort(peer.IDSlice(peers))
708	sort.Sort(peer.IDSlice(expected))
709
710	if len(peers) != len(expected) {
711		t.Fatalf("mismatch: %s != %s", peers, expected)
712	}
713
714	for i, p := range peers {
715		if expected[i] != p {
716			t.Fatalf("mismatch: %s != %s", peers, expected)
717		}
718	}
719}
720
721func TestWithNoSigning(t *testing.T) {
722	ctx, cancel := context.WithCancel(context.Background())
723	defer cancel()
724
725	hosts := getNetHosts(t, ctx, 2)
726	psubs := getPubsubs(ctx, hosts, WithNoAuthor(), WithMessageIdFn(func(pmsg *pb.Message) string {
727		// silly content-based test message-ID: just use the data as whole
728		return base64.URLEncoding.EncodeToString(pmsg.Data)
729	}))
730
731	connect(t, hosts[0], hosts[1])
732
733	topic := "foobar"
734	data := []byte("this is a message")
735
736	sub, err := psubs[1].Subscribe(topic)
737	if err != nil {
738		t.Fatal(err)
739	}
740
741	time.Sleep(time.Millisecond * 10)
742
743	err = psubs[0].Publish(topic, data)
744	if err != nil {
745		t.Fatal(err)
746	}
747
748	msg, err := sub.Next(ctx)
749	if err != nil {
750		t.Fatal(err)
751	}
752	if msg.Signature != nil {
753		t.Fatal("signature in message")
754	}
755	if msg.From != nil {
756		t.Fatal("from in message")
757	}
758	if msg.Seqno != nil {
759		t.Fatal("seqno in message")
760	}
761	if string(msg.Data) != string(data) {
762		t.Fatalf("unexpected data: %s", string(msg.Data))
763	}
764}
765
766func TestWithSigning(t *testing.T) {
767	ctx, cancel := context.WithCancel(context.Background())
768	defer cancel()
769
770	hosts := getNetHosts(t, ctx, 2)
771	psubs := getPubsubs(ctx, hosts, WithStrictSignatureVerification(true))
772
773	connect(t, hosts[0], hosts[1])
774
775	topic := "foobar"
776	data := []byte("this is a message")
777
778	sub, err := psubs[1].Subscribe(topic)
779	if err != nil {
780		t.Fatal(err)
781	}
782
783	time.Sleep(time.Millisecond * 10)
784
785	err = psubs[0].Publish(topic, data)
786	if err != nil {
787		t.Fatal(err)
788	}
789
790	msg, err := sub.Next(ctx)
791	if err != nil {
792		t.Fatal(err)
793	}
794	if msg.Signature == nil {
795		t.Fatal("no signature in message")
796	}
797	if msg.From == nil {
798		t.Fatal("from not in message")
799	}
800	if msg.Seqno == nil {
801		t.Fatal("seqno not in message")
802	}
803	if string(msg.Data) != string(data) {
804		t.Fatalf("unexpected data: %s", string(msg.Data))
805	}
806}
807
808func TestImproperlySignedMessageRejected(t *testing.T) {
809	ctx, cancel := context.WithCancel(context.Background())
810	defer cancel()
811
812	hosts := getNetHosts(t, ctx, 2)
813	adversary := hosts[0]
814	honestPeer := hosts[1]
815
816	// The adversary enables signing, but disables verification to let through
817	// an incorrectly signed message.
818	adversaryPubSub := getPubsub(
819		ctx,
820		adversary,
821		WithMessageSigning(true),
822		WithStrictSignatureVerification(false),
823	)
824	honestPubSub := getPubsub(
825		ctx,
826		honestPeer,
827		WithStrictSignatureVerification(true),
828	)
829
830	connect(t, adversary, honestPeer)
831
832	var (
833		topic            = "foobar"
834		correctMessage   = []byte("this is a correct message")
835		incorrectMessage = []byte("this is the incorrect message")
836	)
837
838	adversarySubscription, err := adversaryPubSub.Subscribe(topic)
839	if err != nil {
840		t.Fatal(err)
841	}
842	honestPeerSubscription, err := honestPubSub.Subscribe(topic)
843	if err != nil {
844		t.Fatal(err)
845	}
846	time.Sleep(time.Millisecond * 50)
847
848	// First the adversary sends the correct message.
849	err = adversaryPubSub.Publish(topic, correctMessage)
850	if err != nil {
851		t.Fatal(err)
852	}
853
854	// Change the sign key for the adversarial peer, and send the second,
855	// incorrectly signed, message.
856	adversaryPubSub.signID = honestPubSub.signID
857	adversaryPubSub.signKey = honestPubSub.host.Peerstore().PrivKey(honestPubSub.signID)
858	err = adversaryPubSub.Publish(topic, incorrectMessage)
859	if err != nil {
860		t.Fatal(err)
861	}
862
863	var adversaryMessages []*Message
864	adversaryContext, adversaryCancel := context.WithCancel(ctx)
865	go func(ctx context.Context) {
866		for {
867			select {
868			case <-ctx.Done():
869				return
870			default:
871				msg, err := adversarySubscription.Next(ctx)
872				if err != nil {
873					return
874				}
875				adversaryMessages = append(adversaryMessages, msg)
876			}
877		}
878	}(adversaryContext)
879
880	<-time.After(1 * time.Second)
881	adversaryCancel()
882
883	// Ensure the adversary successfully publishes the incorrectly signed
884	// message. If the adversary "sees" this, we successfully got through
885	// their local validation.
886	if len(adversaryMessages) != 2 {
887		t.Fatalf("got %d messages, expected 2", len(adversaryMessages))
888	}
889
890	// the honest peer's validation process will drop the message;
891	// next will never furnish the incorrect message.
892	var honestPeerMessages []*Message
893	honestPeerContext, honestPeerCancel := context.WithCancel(ctx)
894	go func(ctx context.Context) {
895		for {
896			select {
897			case <-ctx.Done():
898				return
899			default:
900				msg, err := honestPeerSubscription.Next(ctx)
901				if err != nil {
902					return
903				}
904				honestPeerMessages = append(honestPeerMessages, msg)
905			}
906		}
907	}(honestPeerContext)
908
909	<-time.After(1 * time.Second)
910	honestPeerCancel()
911
912	if len(honestPeerMessages) != 1 {
913		t.Fatalf("got %d messages, expected 1", len(honestPeerMessages))
914	}
915	if string(honestPeerMessages[0].GetData()) != string(correctMessage) {
916		t.Fatalf(
917			"got %s, expected message %s",
918			honestPeerMessages[0].GetData(),
919			correctMessage,
920		)
921	}
922}
923
924func TestMessageSender(t *testing.T) {
925	ctx, cancel := context.WithCancel(context.Background())
926	defer cancel()
927
928	const topic = "foobar"
929
930	hosts := getNetHosts(t, ctx, 3)
931	psubs := getPubsubs(ctx, hosts)
932
933	var msgs []*Subscription
934	for _, ps := range psubs {
935		subch, err := ps.Subscribe(topic)
936		if err != nil {
937			t.Fatal(err)
938		}
939
940		msgs = append(msgs, subch)
941	}
942
943	connect(t, hosts[0], hosts[1])
944	connect(t, hosts[1], hosts[2])
945
946	time.Sleep(time.Millisecond * 100)
947
948	for i := 0; i < 3; i++ {
949		for j := 0; j < 100; j++ {
950			msg := []byte(fmt.Sprintf("%d sent %d", i, j))
951
952			psubs[i].Publish(topic, msg)
953
954			for k, sub := range msgs {
955				got, err := sub.Next(ctx)
956				if err != nil {
957					t.Fatal(sub.err)
958				}
959				if !bytes.Equal(msg, got.Data) {
960					t.Fatal("got wrong message!")
961				}
962
963				var expectedHost int
964				if i == k {
965					expectedHost = i
966				} else if k != 1 {
967					expectedHost = 1
968				} else {
969					expectedHost = i
970				}
971
972				if got.ReceivedFrom != hosts[expectedHost].ID() {
973					t.Fatal("got wrong message sender")
974				}
975			}
976		}
977	}
978}
979
980func TestConfigurableMaxMessageSize(t *testing.T) {
981	ctx, cancel := context.WithCancel(context.Background())
982	defer cancel()
983
984	hosts := getNetHosts(t, ctx, 10)
985
986	// use a 4mb limit; default is 1mb; we'll test with a 2mb payload.
987	psubs := getPubsubs(ctx, hosts, WithMaxMessageSize(1<<22))
988
989	sparseConnect(t, hosts)
990	time.Sleep(time.Millisecond * 100)
991
992	const topic = "foobar"
993	var subs []*Subscription
994	for _, ps := range psubs {
995		subch, err := ps.Subscribe(topic)
996		if err != nil {
997			t.Fatal(err)
998		}
999		subs = append(subs, subch)
1000	}
1001
1002	// 2mb payload.
1003	msg := make([]byte, 1<<21)
1004	rand.Read(msg)
1005	err := psubs[0].Publish(topic, msg)
1006	if err != nil {
1007		t.Fatal(err)
1008	}
1009
1010	// make sure that all peers received the message.
1011	for _, sub := range subs {
1012		got, err := sub.Next(ctx)
1013		if err != nil {
1014			t.Fatal(sub.err)
1015		}
1016		if !bytes.Equal(msg, got.Data) {
1017			t.Fatal("got wrong message!")
1018		}
1019	}
1020
1021}
1022
// TestAnnounceRetry checks that calling announceRetry for a peer re-sends the
// subscription announcement. The watcher peer should count the "test"
// subscribe twice: once from the initial Subscribe and once from the retry.
func TestAnnounceRetry(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hosts := getNetHosts(t, ctx, 2)
	ps := getPubsub(ctx, hosts[0])
	// hosts[1] runs no PubSub. It only counts the subscribe entries it
	// receives on the FloodSub protocol.
	watcher := &announceWatcher{}
	hosts[1].SetStreamHandler(FloodSubID, watcher.handleStream)

	_, err := ps.Subscribe("test")
	if err != nil {
		t.Fatal(err)
	}

	// connect the watcher to the pubsub
	connect(t, hosts[0], hosts[1])

	// wait a bit for the first subscription to be emitted and trigger announce retry
	time.Sleep(100 * time.Millisecond)
	go ps.announceRetry(hosts[1].ID(), "test", true)

	// wait a bit for the subscription to propagate and ensure it was received twice
	// NOTE(review): the extra second presumably covers a delay inside
	// announceRetry — confirm against its implementation.
	time.Sleep(time.Second + 100*time.Millisecond)
	count := watcher.countSubs()
	if count != 2 {
		t.Fatalf("expected 2 subscription messages, but got %d", count)
	}
}
1051
// announceWatcher is a bare stream handler used by TestAnnounceRetry. It
// counts how many subscribe entries for topic "test" arrive in the RPCs
// written to it.
type announceWatcher struct {
	mx   sync.Mutex // guards subs
	subs int        // number of subscribe entries seen for topic "test"
}
1056
1057func (aw *announceWatcher) handleStream(s network.Stream) {
1058	defer s.Close()
1059
1060	r := protoio.NewDelimitedReader(s, 1<<20)
1061
1062	var rpc pb.RPC
1063	for {
1064		rpc.Reset()
1065		err := r.ReadMsg(&rpc)
1066		if err != nil {
1067			if err != io.EOF {
1068				s.Reset()
1069			}
1070			return
1071		}
1072
1073		for _, sub := range rpc.GetSubscriptions() {
1074			if sub.GetSubscribe() && sub.GetTopicid() == "test" {
1075				aw.mx.Lock()
1076				aw.subs++
1077				aw.mx.Unlock()
1078			}
1079		}
1080	}
1081}
1082
1083func (aw *announceWatcher) countSubs() int {
1084	aw.mx.Lock()
1085	defer aw.mx.Unlock()
1086	return aw.subs
1087}
1088
1089func TestPubsubWithAssortedOptions(t *testing.T) {
1090	// this test uses assorted options that are not covered in other tests
1091	ctx, cancel := context.WithCancel(context.Background())
1092	defer cancel()
1093
1094	hashMsgID := func(m *pb.Message) string {
1095		hash := sha256.Sum256(m.Data)
1096		return string(hash[:])
1097	}
1098
1099	hosts := getNetHosts(t, ctx, 2)
1100	psubs := getPubsubs(ctx, hosts,
1101		WithMessageIdFn(hashMsgID),
1102		WithPeerOutboundQueueSize(10),
1103		WithMessageAuthor(""),
1104		WithBlacklist(NewMapBlacklist()))
1105
1106	connect(t, hosts[0], hosts[1])
1107
1108	var subs []*Subscription
1109	for _, ps := range psubs {
1110		sub, err := ps.Subscribe("test")
1111		if err != nil {
1112			t.Fatal(err)
1113		}
1114		subs = append(subs, sub)
1115	}
1116
1117	time.Sleep(time.Second)
1118
1119	for i := 0; i < 2; i++ {
1120		msg := []byte(fmt.Sprintf("message %d", i))
1121		psubs[i].Publish("test", msg)
1122
1123		for _, sub := range subs {
1124			assertReceive(t, sub, msg)
1125		}
1126	}
1127}
1128
1129func TestWithInvalidMessageAuthor(t *testing.T) {
1130	// this test exercises the failure path in the WithMessageAuthor option
1131	ctx, cancel := context.WithCancel(context.Background())
1132	defer cancel()
1133
1134	h := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
1135	_, err := NewFloodSub(ctx, h, WithMessageAuthor("bogotr0n"))
1136	if err == nil {
1137		t.Fatal("expected error")
1138	}
1139}
1140
1141func TestPreconnectedNodes(t *testing.T) {
1142	ctx, cancel := context.WithCancel(context.Background())
1143	defer cancel()
1144	// If this test fails it may hang so set a timeout
1145	ctx, cancel = context.WithTimeout(ctx, time.Second*10)
1146	defer cancel()
1147
1148	// Create hosts
1149	h1 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
1150	h2 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
1151
1152	opts := []Option{WithDiscovery(&dummyDiscovery{})}
1153	// Setup first PubSub
1154	p1, err := NewFloodSub(ctx, h1, opts...)
1155	if err != nil {
1156		t.Fatal(err)
1157	}
1158
1159	// Connect the two hosts together
1160	connect(t, h2, h1)
1161
1162	// Setup the second DHT
1163	p2, err := NewFloodSub(ctx, h2, opts...)
1164	if err != nil {
1165		t.Fatal(err)
1166	}
1167
1168	// See if it works
1169	p2Topic, err := p2.Join("test")
1170	if err != nil {
1171		t.Fatal(err)
1172	}
1173
1174	p1Topic, err := p1.Join("test")
1175	if err != nil {
1176		t.Fatal(err)
1177	}
1178
1179	testPublish := func(publisher, receiver *Topic, msg []byte) {
1180		receiverSub, err := receiver.Subscribe()
1181		if err != nil {
1182			t.Fatal(err)
1183		}
1184
1185		if err := publisher.Publish(ctx, msg, WithReadiness(MinTopicSize(1))); err != nil {
1186			t.Fatal(err)
1187		}
1188
1189		m, err := receiverSub.Next(ctx)
1190		if err != nil {
1191			t.Fatal(err)
1192		}
1193
1194		if receivedData := m.GetData(); !bytes.Equal(receivedData, msg) {
1195			t.Fatalf("expected message %v, got %v", msg, receivedData)
1196		}
1197	}
1198
1199	// Test both directions since PubSub uses one directional streams
1200	testPublish(p1Topic, p2Topic, []byte("test1-to-2"))
1201	testPublish(p1Topic, p2Topic, []byte("test2-to-1"))
1202}
1203
1204func TestDedupInboundStreams(t *testing.T) {
1205	ctx, cancel := context.WithCancel(context.Background())
1206	defer cancel()
1207
1208	h1 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
1209	h2 := bhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
1210
1211	_, err := NewFloodSub(ctx, h1)
1212	if err != nil {
1213		t.Fatal(err)
1214	}
1215
1216	// Connect the two hosts together
1217	connect(t, h2, h1)
1218
1219	// open a few streams and make sure all but the last one get reset
1220	s1, err := h2.NewStream(ctx, h1.ID(), FloodSubID)
1221	if err != nil {
1222		t.Fatal(err)
1223	}
1224	time.Sleep(100 * time.Millisecond)
1225
1226	s2, err := h2.NewStream(ctx, h1.ID(), FloodSubID)
1227	if err != nil {
1228		t.Fatal(err)
1229	}
1230	time.Sleep(100 * time.Millisecond)
1231
1232	s3, err := h2.NewStream(ctx, h1.ID(), FloodSubID)
1233	if err != nil {
1234		t.Fatal(err)
1235	}
1236	time.Sleep(100 * time.Millisecond)
1237
1238	// check that s1 and s2 have been reset
1239	_, err = s1.Read([]byte{0})
1240	if err == nil {
1241		t.Fatal("expected s1 to be reset")
1242	}
1243
1244	_, err = s2.Read([]byte{0})
1245	if err == nil {
1246		t.Fatal("expected s2 to be reset")
1247	}
1248
1249	// check that s3 is readable and simply times out
1250	s3.SetReadDeadline(time.Now().Add(time.Millisecond))
1251	_, err = s3.Read([]byte{0})
1252	err2, ok := err.(interface{ Timeout() bool })
1253	if !ok || !err2.Timeout() {
1254		t.Fatal(err)
1255	}
1256}
1257