1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package kex2
5
6import (
7	"bytes"
8	"crypto/rand"
9	"io"
10	"net"
11	"runtime"
12	"strings"
13	"sync"
14	"testing"
15	"time"
16
17	"github.com/stretchr/testify/require"
18	"golang.org/x/net/context"
19)
20
// message is a single packet queued in a simplexSession together with the
// sequence number it was posted under.
type message struct {
	seqno Seqno  // sequence number passed to post (get does not read it)
	msg   []byte // the packet bytes
}

// simplexSession is a one-directional, in-memory packet queue inside a
// session: packets posted by one device wait here until the other device
// collects them.
type simplexSession struct {
	// ch holds posted packets in arrival order; newSimplexSession sets its
	// buffer size.
	ch chan message
}
29
30var zeroDeviceID DeviceID
31
32func (d DeviceID) isZero() bool {
33	return d.Eq(zeroDeviceID)
34}
35
36func newSimplexSession() *simplexSession {
37	return &simplexSession{
38		ch: make(chan message, 100),
39	}
40}
41
42type session struct {
43	id              SessionID
44	devices         [2]DeviceID
45	simplexSessions [2](*simplexSession)
46}
47
48func newSession(i SessionID) *session {
49	sess := &session{id: i}
50	for j := 0; j < 2; j++ {
51		sess.simplexSessions[j] = newSimplexSession()
52	}
53	return sess
54}
55
56func (s *session) getDeviceNumber(d DeviceID) int {
57	if s.devices[0].Eq(d) {
58		return 0
59	}
60	if s.devices[0].isZero() {
61		s.devices[0] = d
62		return 0
63	}
64	s.devices[1] = d
65	return 1
66}
67
68type mockRouter struct {
69	behavior int
70	maxPoll  time.Duration
71
72	sessionMutex sync.Mutex
73	sessions     map[SessionID]*session
74}
75
// Router behaviors. GoodRouter delivers packets untouched. Each BadRouter*
// value is a distinct bit, so several faults can be OR-ed together. Bit 0
// is deliberately left unused, matching the original 1<<iota numbering.
const (
	GoodRouter                   = 0
	BadRouterCorruptedSession    = 1 << 1
	BadRouterCorruptedSender     = 1 << 2
	BadRouterCorruptedCiphertext = 1 << 3
	BadRouterReorder             = 1 << 4
	BadRouterDrop                = 1 << 5
)
84
85func corruptMessage(behavior int, msg []byte) {
86	if (behavior & BadRouterCorruptedSession) != 0 {
87		msg[23] ^= 0x80
88	}
89	if (behavior & BadRouterCorruptedSender) != 0 {
90		msg[10] ^= 0x40
91	}
92	if (behavior & BadRouterCorruptedCiphertext) != 0 {
93		msg[len(msg)-10] ^= 0x01
94	}
95}
96
97func newMockRouterWithBehavior(b int) *mockRouter {
98	return &mockRouter{
99		behavior: b,
100		sessions: make(map[SessionID]*session),
101	}
102}
103
104func newMockRouterWithBehaviorAndMaxPoll(b int, mp time.Duration) *mockRouter {
105	return &mockRouter{
106		behavior: b,
107		maxPoll:  mp,
108		sessions: make(map[SessionID]*session),
109	}
110}
111
112func (ss *simplexSession) post(seqno Seqno, msg []byte) error {
113	ss.ch <- message{seqno, msg}
114	return nil
115}
116
// lookupType says whether a device ID names the writer or the reader of the
// one-way queue being looked up.
type lookupType int

const (
	// bySender selects the queue the given device writes to.
	bySender lookupType = iota
	// byReceiver selects the queue the given device reads from.
	byReceiver
)
123
124func (s *session) findOrMakeSimplexSession(sender DeviceID, lt lookupType) *simplexSession {
125	i := s.getDeviceNumber(sender)
126	if lt == byReceiver {
127		i = 1 - i
128	}
129	return s.simplexSessions[i]
130}
131
132func (mr *mockRouter) findOrMakeSimplexSession(i SessionID, sender DeviceID, lt lookupType) *simplexSession {
133	mr.sessionMutex.Lock()
134	defer mr.sessionMutex.Unlock()
135
136	sess, ok := mr.sessions[i]
137	if !ok {
138		sess = newSession(i)
139		mr.sessions[i] = sess
140	}
141	return sess.findOrMakeSimplexSession(sender, lt)
142}
143
144func (mr *mockRouter) Post(i SessionID, sender DeviceID, seqno Seqno, msg []byte) error {
145	ss := mr.findOrMakeSimplexSession(i, sender, bySender)
146	corruptMessage(mr.behavior, msg)
147	return ss.post(seqno, msg)
148}
149
150func (ss *simplexSession) get(seqno Seqno, poll time.Duration, behavior int) (ret [][]byte, err error) {
151	timeout := false
152	handleMessage := func(msg message) {
153		ret = append(ret, msg.msg)
154	}
155	if poll.Nanoseconds() > 0 {
156		select {
157		case msg := <-ss.ch:
158			handleMessage(msg)
159		case <-time.After(poll):
160			timeout = true
161		}
162	}
163	if !timeout {
164	loopMessages:
165		for {
166			select {
167			case msg := <-ss.ch:
168				handleMessage(msg)
169			default:
170				break loopMessages
171			}
172		}
173	}
174
175	if (behavior&BadRouterReorder) != 0 && len(ret) > 1 {
176		ret[0], ret[1] = ret[1], ret[0]
177	}
178	if (behavior&BadRouterDrop) != 0 && len(ret) > 1 {
179		ret = ret[1:]
180	}
181
182	return ret, err
183}
184
185func (mr *mockRouter) Get(i SessionID, receiver DeviceID, seqno Seqno, poll time.Duration) ([][]byte, error) {
186	ss := mr.findOrMakeSimplexSession(i, receiver, byReceiver)
187	if mr.maxPoll > time.Duration(0) && poll > mr.maxPoll {
188		poll = mr.maxPoll
189	}
190	return ss.get(seqno, poll, mr.behavior)
191}
192
193func genSecret(t *testing.T) (ret Secret) {
194	_, err := rand.Read(ret[:])
195	if err != nil {
196		t.Fatal(err)
197	}
198	return ret
199}
200
201func genDeviceID(t *testing.T) (ret DeviceID) {
202	_, err := rand.Read(ret[:])
203	if err != nil {
204		t.Fatal(err)
205	}
206	return ret
207}
208
// testLogCtx forwards Debug output to a *testing.T. The embedded mutex
// guards t. The closer returned by newTestLogCtx clears t, after which Debug
// calls are dropped. This lets goroutines that outlive a test keep logging
// harmlessly.
type testLogCtx struct {
	sync.Mutex
	t *testing.T
}

// newTestLogCtx wraps t in a testLogCtx. The returned closer detaches t;
// tests defer it.
func newTestLogCtx(t *testing.T) (ret *testLogCtx, closer func()) {
	lc := &testLogCtx{t: t}
	detach := func() {
		lc.Lock()
		defer lc.Unlock()
		lc.t = nil
	}
	return lc, detach
}

// Debug logs a formatted line through the wrapped *testing.T. It does
// nothing once the closer has run.
func (t *testLogCtx) Debug(format string, args ...interface{}) {
	t.Lock()
	defer t.Unlock()
	if t.t == nil {
		return
	}
	t.t.Logf(format, args...)
}
231
232func genNewConn(t *testLogCtx, mr MessageRouter, s Secret, d DeviceID, rt time.Duration) net.Conn {
233	ret, err := NewConn(context.TODO(), t, mr, s, d, rt)
234	if err != nil {
235		t.t.Fatal(err)
236	}
237	return ret
238}
239
240func genConnPair(t *testLogCtx, behavior int, readTimeout time.Duration) (c1 net.Conn, c2 net.Conn, d1 DeviceID, d2 DeviceID) {
241	r := newMockRouterWithBehavior(behavior)
242	s := genSecret(t.t)
243	d1 = genDeviceID(t.t)
244	d2 = genDeviceID(t.t)
245	c1 = genNewConn(t, r, s, d1, readTimeout)
246	c2 = genNewConn(t, r, s, d2, readTimeout)
247	return
248}
249
// maybeDisableTest skips the calling test on Windows.
// NOTE(review): presumably because the timing-sensitive tests that call it
// are unreliable there; confirm.
func maybeDisableTest(t *testing.T) {
	if runtime.GOOS != "windows" {
		return
	}
	t.Skip()
}
255
256func TestHello(t *testing.T) {
257	testLogCtx, cleanup := newTestLogCtx(t)
258	defer cleanup()
259	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
260	txt := []byte("hello friend")
261	if _, err := c1.Write(txt); err != nil {
262		t.Fatal(err)
263	}
264	buf := make([]byte, 100)
265	if n, err := c2.Read(buf); err != nil {
266		t.Fatal(err)
267	} else if n != len(txt) {
268		t.Fatal("bad read len")
269	} else if !bytes.Equal(buf[0:n], txt) {
270		t.Fatal("wrong message back")
271	}
272	txt2 := []byte("pong PONG pong PONG pong PONG")
273	if _, err := c2.Write(txt2); err != nil {
274		t.Fatal(err)
275	} else if n, err := c1.Read(buf); err != nil {
276		t.Fatal(err)
277	} else if n != len(txt2) {
278		t.Fatal("bad read len")
279	} else if !bytes.Equal(buf[0:n], txt2) {
280		t.Fatal("wrong ponged text")
281	}
282}
283
284func TestBadMetadata(t *testing.T) {
285	testLogCtx, cleanup := newTestLogCtx(t)
286	defer cleanup()
287
288	testBehavior := func(b int, wanted error) {
289		c1, c2, _, _ := genConnPair(testLogCtx, b, time.Duration(0))
290		txt := []byte("hello friend")
291		if _, err := c1.Write(txt); err != nil {
292			t.Fatal(err)
293		}
294		buf := make([]byte, 100)
295		if _, err := c2.Read(buf); err == nil {
296			t.Fatalf("behavior %d: wanted an error, didn't get one", b)
297		} else if err != wanted {
298			t.Fatalf("behavior %d: wanted error '%v', got '%v'", b, err, wanted)
299		}
300	}
301	testBehavior(BadRouterCorruptedSession, ErrBadMetadata)
302	testBehavior(BadRouterCorruptedSender, ErrBadMetadata)
303	testBehavior(BadRouterCorruptedCiphertext, ErrDecryption)
304}
305
306func TestReadDeadline(t *testing.T) {
307	testLogCtx, cleanup := newTestLogCtx(t)
308	defer cleanup()
309	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
310	wait := time.Duration(10) * time.Millisecond
311	err := c2.SetReadDeadline(time.Now().Add(wait))
312	require.NoError(t, err)
313	go func() {
314		time.Sleep(wait * 2)
315		_, _ = c1.Write([]byte("hello friend"))
316	}()
317	buf := make([]byte, 100)
318	_, err = c2.Read(buf)
319	if err != ErrTimedOut {
320		t.Fatalf("wanted a read timeout")
321	}
322}
323
324func TestReadTimeout(t *testing.T) {
325	testLogCtx, cleanup := newTestLogCtx(t)
326	defer cleanup()
327	wait := time.Duration(10) * time.Millisecond
328	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, wait)
329	go func() {
330		time.Sleep(wait * 2)
331		_, _ = c1.Write([]byte("hello friend"))
332	}()
333	buf := make([]byte, 100)
334	_, err := c2.Read(buf)
335	if err != ErrTimedOut {
336		t.Fatalf("wanted a read timeout")
337	}
338}
339
340func TestReadDelayedWrite(t *testing.T) {
341	maybeDisableTest(t)
342	testLogCtx, cleanup := newTestLogCtx(t)
343	defer cleanup()
344	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
345	wait := time.Duration(50) * time.Millisecond
346	err := c2.SetReadDeadline(time.Now().Add(wait))
347	require.NoError(t, err)
348	text := "hello friend"
349	go func() {
350		time.Sleep(wait / 32)
351		_, _ = c1.Write([]byte(text))
352	}()
353	buf := make([]byte, 100)
354	n, err := c2.Read(buf)
355	if err != nil {
356		t.Fatal(err)
357	}
358	if n != len(text) {
359		t.Fatalf("wrong read length")
360	}
361}
362
363func TestMultipleWritesOneRead(t *testing.T) {
364	maybeDisableTest(t)
365	testLogCtx, cleanup := newTestLogCtx(t)
366	defer cleanup()
367	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
368	msgs := []string{
369		"Alas, poor Yorick! I knew him, Horatio: a fellow",
370		"of infinite jest, of most excellent fancy: he hath",
371		"borne me on his back a thousand times; and now, how",
372		"abhorred in my imagination it is! my gorge rims at",
373		"it.",
374	}
375	for i, m := range msgs {
376		if i > 0 {
377			m = "\n" + m
378		}
379		if _, err := c1.Write([]byte(m)); err != nil {
380			t.Fatal(err)
381		}
382	}
383	buf := make([]byte, 1000)
384	if n, err := c2.Read(buf); err != nil {
385		t.Fatal(err)
386	} else if strings.Join(msgs, "\n") != string(buf[0:n]) {
387		t.Fatal("string mismatch")
388	}
389}
390
391func TestOneWriteMultipleReads(t *testing.T) {
392	testLogCtx, cleanup := newTestLogCtx(t)
393	defer cleanup()
394	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
395	msg := `Crows maunder on the petrified fairway.
396Absence! My heart grows tense
397as though a harpoon were sparring for the kill.`
398	if _, err := c1.Write([]byte(msg)); err != nil {
399		return
400	}
401	small := make([]byte, 3)
402	var buf []byte
403	for {
404		if n, err := c2.Read(small); err != nil && err != ErrAgain {
405			t.Fatal(err)
406		} else if n == 0 {
407			if err != ErrAgain {
408				t.Fatalf("exepcted ErrAgain if we read 0 bytes, but got %v", err)
409			}
410			break
411		} else {
412			buf = append(buf, small[0:n]...)
413		}
414	}
415	if string(buf) != msg {
416		t.Fatal("message mismatch")
417	}
418}
419
420func TestReorder(t *testing.T) {
421	testLogCtx, cleanup := newTestLogCtx(t)
422	defer cleanup()
423	c1, c2, _, _ := genConnPair(testLogCtx, BadRouterReorder, time.Duration(0))
424	msgs := []string{
425		"Alas, poor Yorick! I knew him, Horatio: a fellow",
426		"of infinite jest, of most excellent fancy: he hath",
427		"borne me on his back a thousand times; and now, how",
428		"abhorred in my imagination it is! my gorge rims at",
429		"it.",
430	}
431	for i, m := range msgs {
432		if i > 0 {
433			m = "\n" + m
434		}
435		if _, err := c1.Write([]byte(m)); err != nil {
436			t.Fatal(err)
437		}
438	}
439	buf := make([]byte, 1000)
440	_, err := c2.Read(buf)
441	if _, ok := err.(ErrBadPacketSequence); !ok {
442		t.Fatalf("expected an ErrBadPacketSequence; got %v", err)
443	}
444}
445
446func TestDrop(t *testing.T) {
447	testLogCtx, cleanup := newTestLogCtx(t)
448	defer cleanup()
449	c1, c2, _, _ := genConnPair(testLogCtx, BadRouterDrop, time.Duration(0))
450	msgs := []string{
451		"Alas, poor Yorick! I knew him, Horatio: a fellow",
452		"of infinite jest, of most excellent fancy: he hath",
453		"borne me on his back a thousand times; and now, how",
454		"abhorred in my imagination it is! my gorge rims at",
455		"it.",
456	}
457	for i, m := range msgs {
458		if i > 0 {
459			m = "\n" + m
460		}
461		if _, err := c1.Write([]byte(m)); err != nil {
462			t.Fatal(err)
463		}
464	}
465	buf := make([]byte, 1000)
466	_, err := c2.Read(buf)
467	if _, ok := err.(ErrBadPacketSequence); !ok {
468		t.Fatalf("expected an ErrBadPacketSequence; got %v", err)
469	}
470}
471
472func TestClose(t *testing.T) {
473	testLogCtx, cleanup := newTestLogCtx(t)
474	defer cleanup()
475	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(4)*time.Second)
476	msg := "Hello friend. I'm going to mic drop."
477	if _, err := c1.Write([]byte(msg)); err != nil {
478		t.Fatal(err)
479	}
480	if err := c1.Close(); err != nil {
481		t.Fatal(err)
482	}
483	buf := make([]byte, 1000)
484	if n, err := c2.Read(buf); err != nil {
485		t.Fatal(err)
486	} else if n != len(msg) {
487		t.Fatalf("short read: %d v %d: %v", n, len(msg), msg)
488	} else if string(buf[0:n]) != msg {
489		t.Fatal("wrong msg")
490	}
491
492	// Assert we get an EOF now and forever...
493	for i := 0; i < 3; i++ {
494		if n, err := c2.Read(buf); err != io.EOF {
495			t.Fatalf("expected EOF, but got err = %v", err)
496		} else if n != 0 {
497			t.Fatalf("Expected 0-length read, but got %d", n)
498		}
499	}
500}
501
502func TestErrAgain(t *testing.T) {
503	testLogCtx, cleanup := newTestLogCtx(t)
504	defer cleanup()
505	_, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
506	buf := make([]byte, 100)
507	if n, err := c2.Read(buf); err != ErrAgain {
508		t.Fatalf("wanted ErrAgain, but got err = %v", err)
509	} else if n != 0 {
510		t.Fatalf("Wanted 0 bytes back; got %d", n)
511	}
512}
513
514func TestPollLoopSuccess(t *testing.T) {
515	maybeDisableTest(t)
516
517	testLogCtx, cleanup := newTestLogCtx(t)
518	defer cleanup()
519
520	wait := time.Duration(100) * time.Millisecond
521	r := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, wait/128)
522	s := genSecret(t)
523	d1 := genDeviceID(t)
524	d2 := genDeviceID(t)
525	c1 := genNewConn(testLogCtx, r, s, d1, wait)
526	c2 := genNewConn(testLogCtx, r, s, d2, wait)
527
528	text := "poll for this, will you?"
529
530	go func() {
531		time.Sleep(wait / 32)
532		_, _ = c1.Write([]byte(text))
533	}()
534	buf := make([]byte, 100)
535	n, err := c2.Read(buf)
536	if err != nil {
537		t.Fatal(err)
538	}
539	if n != len(text) {
540		t.Fatalf("wrong read length")
541	}
542}
543
544func TestPollLoopTimeout(t *testing.T) {
545	maybeDisableTest(t)
546
547	testLogCtx, cleanup := newTestLogCtx(t)
548	defer cleanup()
549
550	wait := time.Duration(8) * time.Millisecond
551	r := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, wait/32)
552	s := genSecret(t)
553	d1 := genDeviceID(t)
554	d2 := genDeviceID(t)
555	c1 := genNewConn(testLogCtx, r, s, d1, wait)
556	c2 := genNewConn(testLogCtx, r, s, d2, wait)
557
558	text := "poll for this, will you?"
559
560	go func() {
561		time.Sleep(wait * 2)
562		_, _ = c1.Write([]byte(text))
563	}()
564	buf := make([]byte, 100)
565	if _, err := c2.Read(buf); err != ErrTimedOut {
566		t.Fatalf("Wanted ErrTimedOut; got %v", err)
567	}
568}
569