1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package libkb
5
6import (
7	"sync"
8	"testing"
9	"time"
10
11	"crypto/rand"
12
13	"github.com/keybase/client/go/kex2"
14)
15
16type ktester struct {
17	sender   kex2.DeviceID
18	receiver kex2.DeviceID
19	I        kex2.SessionID
20	seqno    kex2.Seqno
21}
22
23func newKtester() *ktester {
24	kt := &ktester{}
25	if _, err := rand.Read(kt.sender[:]); err != nil {
26		panic(err)
27	}
28	if _, err := rand.Read(kt.receiver[:]); err != nil {
29		panic(err)
30	}
31	if _, err := rand.Read(kt.I[:]); err != nil {
32		panic(err)
33	}
34
35	return kt
36}
37
38func (k *ktester) post(mr kex2.MessageRouter, b []byte) error {
39	k.seqno++
40	return mr.Post(k.I, k.sender, k.seqno, b)
41}
42
43func (k *ktester) get(mr kex2.MessageRouter, low kex2.Seqno, poll time.Duration) ([][]byte, error) {
44	return mr.Get(k.I, k.receiver, low, poll)
45}
46
47func TestKex2Router(t *testing.T) {
48	tc := SetupTest(t, "kex2 router", 1)
49	defer tc.Cleanup()
50
51	mr := NewKexRouter(NewMetaContextTODO(tc.G))
52	kt := newKtester()
53
54	m1 := "hello everybody"
55	m2 := "goodbye everybody"
56	m3 := "plaid shirt"
57
58	// test send 2 messages
59	if err := kt.post(mr, []byte(m1)); err != nil {
60		t.Fatal(err)
61	}
62
63	if err := kt.post(mr, []byte(m2)); err != nil {
64		t.Fatal(err)
65	}
66
67	// test receive 2 messages
68	msgs, err := kt.get(mr, 0, 100*time.Millisecond)
69	if err != nil {
70		t.Fatal(err)
71	}
72	if len(msgs) != 2 {
73		t.Fatalf("number of messages: %d, expected 2", len(msgs))
74	}
75	if string(msgs[0]) != m1 {
76		t.Errorf("message 0: %q, expected %q", msgs[0], m1)
77	}
78	if string(msgs[1]) != m2 {
79		t.Errorf("message 1: %q, expected %q", msgs[1], m2)
80	}
81
82	// test calling receive before send
83	var wg sync.WaitGroup
84	wg.Add(1)
85	go func() {
86		defer wg.Done()
87		var merr error
88		// Very large timeout, for the benefit of CI, which may be slow
89		msgs, merr = kt.get(mr, 3, 10*time.Second)
90		if merr != nil {
91			t.Errorf("receive error: %s", merr)
92		}
93	}()
94
95	time.Sleep(3 * time.Millisecond)
96	if err := kt.post(mr, []byte(m3)); err != nil {
97		t.Fatal(err)
98	}
99
100	wg.Wait()
101	if len(msgs) != 1 {
102		t.Fatalf("number of messages: %d, expected 1", len(msgs))
103	}
104	if string(msgs[0]) != m3 {
105		t.Errorf("message: %q, expected %q", msgs[0], m3)
106		t.Errorf("Full message vector was: %v\n", msgs)
107	}
108
109	// test no messages ready
110	msgs, err = kt.get(mr, 4, 1*time.Millisecond)
111	if err != nil {
112		t.Fatal(err)
113	}
114	if len(msgs) != 0 {
115		t.Errorf("number of messages: %d, expected 0", len(msgs))
116	}
117}
118