1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package kex2
5
6import (
7	"crypto/rand"
8	"encoding/hex"
9	"errors"
10	"io"
11	"strings"
12	"testing"
13	"time"
14
15	keybase1 "github.com/keybase/client/go/protocol/keybase1"
16	"github.com/keybase/go-framed-msgpack-rpc/rpc"
17	"golang.org/x/net/context"
18)
19
// Behavior flags for mockProvisionee. GoodProvisionee (zero) injects nothing;
// each BadProvisionee* constant is a distinct bit, so flags can be OR'd
// together to combine failures, delays, and (in testProtocolXWithBehavior)
// cancellation of the provisioner's context.
const (
	GoodProvisionee         = 0
	BadProvisioneeFailHello = 1 << iota // iota is 1 on this line, so this is 2
	BadProvisioneeFailDidCounterSign     // 4
	BadProvisioneeSlowHello              // 8
	BadProvisioneeSlowDidCounterSign     // 16
	BadProvisioneeCancel                 // 32
)
28
// mockProvisioner is a test double for the provisioner side of kex2.
// GetHelloArg/GetHello2Arg report uid, and CounterSign echoes its input.
type mockProvisioner struct {
	// uid is the UID placed in the hello arguments; newMockProvisioner
	// fills it with a random value.
	uid keybase1.UID
}

// mockProvisionee is a test double for the provisionee side of kex2.
type mockProvisionee struct {
	// behavior is a bitmask of the BadProvisionee* constants
	// (GoodProvisionee for none) selecting the failures/delays to inject.
	behavior int
}
36
// newMockProvisioner builds a mockProvisioner whose uid is freshly
// generated with genUID.
func newMockProvisioner(t *testing.T) *mockProvisioner {
	p := new(mockProvisioner)
	p.uid = genUID(t)
	return p
}
42
43type nullLogOutput struct {
44}
45
46func (n *nullLogOutput) Error(s string, args ...interface{})   {}
47func (n *nullLogOutput) Warning(s string, args ...interface{}) {}
48func (n *nullLogOutput) Info(s string, args ...interface{})    {}
49func (n *nullLogOutput) Debug(s string, args ...interface{})   {}
50func (n *nullLogOutput) Profile(s string, args ...interface{}) {}
51
52var _ rpc.LogOutput = (*nullLogOutput)(nil)
53
// makeLogFactory returns the log factory handed to both mocks. Without -v it
// returns a simple factory backed by nullLogOutput, discarding RPC log
// output. With -v it returns nil — presumably so the rpc package falls back
// to its default logging (NOTE(review): confirm nil handling in rpc).
func makeLogFactory() rpc.LogFactory {
	if !testing.Verbose() {
		return rpc.NewSimpleLogFactory(&nullLogOutput{}, nil)
	}
	return nil
}
60
// genUID returns a mock keybase1.UID made of 8 random bytes, hex-encoded
// (16 characters). It fails the test via t.Fatalf if crypto/rand errors.
// t.Helper marks this as a helper so a failure is attributed to the caller's
// line rather than to this function.
func genUID(t *testing.T) keybase1.UID {
	t.Helper()
	uid := make([]byte, 8)
	if _, err := rand.Read(uid); err != nil {
		t.Fatalf("rand failed: %v\n", err)
	}
	return keybase1.UID(hex.EncodeToString(uid))
}

// genKeybase1DeviceID returns a mock keybase1.DeviceID made of 16 random
// bytes, hex-encoded (32 characters). It fails the test via t.Fatalf if
// crypto/rand errors; as a marked helper, the failure is reported at the
// caller's line.
func genKeybase1DeviceID(t *testing.T) keybase1.DeviceID {
	t.Helper()
	did := make([]byte, 16)
	if _, err := rand.Read(did); err != nil {
		t.Fatalf("rand failed: %v\n", err)
	}
	return keybase1.DeviceID(hex.EncodeToString(did))
}
76
// newMockProvisionee returns a mockProvisionee that injects the given
// BadProvisionee* behavior flags. t is unused; it is accepted for symmetry
// with newMockProvisioner.
func newMockProvisionee(t *testing.T, behavior int) *mockProvisionee {
	return &mockProvisionee{behavior: behavior}
}
80
81func (mp *mockProvisioner) GetLogFactory() rpc.LogFactory {
82	return makeLogFactory()
83}
84
85func (mp *mockProvisioner) GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage {
86	return &rpc.DummyInstrumentationStorage{}
87}
88
89func (mp *mockProvisioner) CounterSign(input keybase1.HelloRes) (output []byte, err error) {
90	output = []byte(string(input))
91	return
92}
93
94func (mp *mockProvisioner) CounterSign2(input keybase1.Hello2Res) (output keybase1.DidCounterSign2Arg, err error) {
95	output.Sig, err = mp.CounterSign(input.SigPayload)
96	return
97}
98
99func (mp *mockProvisioner) GetHelloArg() (res keybase1.HelloArg, err error) {
100	res.Uid = mp.uid
101	return res, err
102}
103func (mp *mockProvisioner) GetHello2Arg() (res keybase1.Hello2Arg, err error) {
104	res.Uid = mp.uid
105	return res, err
106}
107
108func (mp *mockProvisionee) GetLogFactory() rpc.LogFactory {
109	return makeLogFactory()
110}
111
112func (mp *mockProvisionee) GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage {
113	return &rpc.DummyInstrumentationStorage{}
114}
115
116var ErrHandleHello = errors.New("handle hello failure")
117var ErrHandleDidCounterSign = errors.New("handle didCounterSign failure")
118var testTimeout = time.Duration(500) * time.Millisecond
119
// HandleHello2 adapts a Hello2Arg to HandleHello (copying Uid and SigBody)
// and wraps the resulting payload in a Hello2Res.
func (mp *mockProvisionee) HandleHello2(ctx context.Context, arg2 keybase1.Hello2Arg) (keybase1.Hello2Res, error) {
	payload, err := mp.HandleHello(ctx, keybase1.HelloArg{Uid: arg2.Uid, SigBody: arg2.SigBody})
	return keybase1.Hello2Res{SigPayload: payload}, err
}

// HandleHello echoes arg.SigBody back as the HelloRes. With
// BadProvisioneeSlowHello it first sleeps 8*testTimeout (the context is
// ignored); with BadProvisioneeFailHello it returns ErrHandleHello and a zero
// result.
func (mp *mockProvisionee) HandleHello(_ context.Context, arg keybase1.HelloArg) (res keybase1.HelloRes, err error) {
	if mp.behavior&BadProvisioneeSlowHello != 0 {
		time.Sleep(8 * testTimeout)
	}
	if mp.behavior&BadProvisioneeFailHello != 0 {
		return res, ErrHandleHello
	}
	return keybase1.HelloRes(arg.SigBody), nil
}

// HandleDidCounterSign ignores the signature it is given. With
// BadProvisioneeSlowDidCounterSign it first sleeps 8*testTimeout; with
// BadProvisioneeFailDidCounterSign it returns ErrHandleDidCounterSign.
func (mp *mockProvisionee) HandleDidCounterSign(_ context.Context, _ []byte) error {
	if mp.behavior&BadProvisioneeSlowDidCounterSign != 0 {
		time.Sleep(8 * testTimeout)
	}
	if mp.behavior&BadProvisioneeFailDidCounterSign == 0 {
		return nil
	}
	return ErrHandleDidCounterSign
}

// HandleDidCounterSign2 forwards arg.Sig to HandleDidCounterSign.
func (mp *mockProvisionee) HandleDidCounterSign2(ctx context.Context, arg keybase1.DidCounterSign2Arg) error {
	return mp.HandleDidCounterSign(ctx, arg.Sig)
}
154
// testProtocolXWithBehavior runs RunProvisioner and RunProvisionee against a
// single GoodRouter mock router and returns the error each one produced.
//
// The provisioner starts with its own random Secret and receives the
// provisionee's secret s2 over its SecretChannel. Nothing is ever written to
// the provisionee's SecretChannel. provisioneeBehavior is a bitmask of the
// BadProvisionee* constants handed to the mock provisionee. If it includes
// BadProvisioneeCancel, the provisioner's context is cancelled
// testTimeout/20 after start.
//
// results are filled in the order the two sides finish, so results[0] is not
// necessarily the provisioner's error.
func testProtocolXWithBehavior(t *testing.T, provisioneeBehavior int) (results [2]error) {
	timeout := testTimeout
	router := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, timeout)

	// Build every input here on the test goroutine. genUID (reached via
	// newMockProvisioner) and genKeybase1DeviceID fail with t.Fatalf, and the
	// testing package only permits FailNow from the goroutine running the
	// test, not from the worker goroutines started below.
	s2 := genSecret(t)
	provisionerSecret := genSecret(t)
	provisionerDeviceID := genKeybase1DeviceID(t)
	provisioneeDeviceID := genKeybase1DeviceID(t)
	mockProv := newMockProvisioner(t)
	mockProvee := newMockProvisionee(t, provisioneeBehavior)

	// Buffered beyond the two senders so neither worker blocks on its send.
	ch := make(chan error, 3)

	secretCh := make(chan Secret)

	ctx, cancelFn := context.WithCancel(context.Background())
	// Release the context on every path, not only when BadProvisioneeCancel
	// triggers the explicit cancel below.
	defer cancelFn()

	testLogCtx, cleanup := newTestLogCtx(t)
	defer cleanup()

	// Run the provisioner.
	go func() {
		ch <- RunProvisioner(ProvisionerArg{
			KexBaseArg: KexBaseArg{
				Ctx:           ctx,
				LogCtx:        testLogCtx,
				Mr:            router,
				Secret:        provisionerSecret,
				DeviceID:      provisionerDeviceID,
				SecretChannel: secretCh,
				Timeout:       timeout,
			},
			Provisioner: mockProv,
		})
	}()

	// Run the provisionee.
	go func() {
		ch <- RunProvisionee(ProvisioneeArg{
			KexBaseArg: KexBaseArg{
				Ctx:           context.Background(),
				LogCtx:        testLogCtx,
				Mr:            router,
				Secret:        s2,
				DeviceID:      provisioneeDeviceID,
				SecretChannel: make(chan Secret),
				Timeout:       timeout,
			},
			Provisionee: mockProvee,
		})
	}()

	if provisioneeBehavior&BadProvisioneeCancel != 0 {
		go func() {
			time.Sleep(testTimeout / 20)
			cancelFn()
		}()
	}

	secretCh <- s2

	for i := range results {
		e, ok := <-ch
		if !ok {
			t.Fatalf("got unexpected channel close (try %d)", i)
		}
		results[i] = e
	}

	return results
}
224
// TestFullProtocolXSuccess requires both sides of a well-behaved protocol-X
// exchange to finish without error.
func TestFullProtocolXSuccess(t *testing.T) {
	for i, e := range testProtocolXWithBehavior(t, GoodProvisionee) {
		if e == nil {
			continue
		}
		t.Fatalf("Bad error %d: %v", i, e)
	}
}
233
// eeq reports whether e1 is non-nil and has the same message as e2.
// Errors are exported as strings, so comparing messages (rather than
// identity) is how these tests check that the right kind of error came back.
// e2 must be non-nil when e1 is non-nil.
func eeq(e1, e2 error) bool {
	if e1 == nil {
		return false
	}
	return e1.Error() == e2.Error()
}
239
// errHasSuffix reports whether err is non-nil and its message ends with
// errSuffix's message. A suffix match is needed because go-codec prepends
// its own text to any errors it catches.
func errHasSuffix(err, errSuffix error) bool {
	if err == nil {
		return false
	}
	msg, want := err.Error(), errSuffix.Error()
	return strings.HasSuffix(msg, want)
}
246
// TestFullProtocolXProvisioneeFailHello requires both sides to report
// ErrHandleHello when the provisionee fails its hello handler.
func TestFullProtocolXProvisioneeFailHello(t *testing.T) {
	for i, e := range testProtocolXWithBehavior(t, BadProvisioneeFailHello) {
		if !eeq(e, ErrHandleHello) {
			t.Fatalf("Bad error %d: %v", i, e)
		}
	}
}

// TestFullProtocolXProvisioneeFailDidCounterSign requires both sides to
// report ErrHandleDidCounterSign when the provisionee fails its
// did-counter-sign handler.
func TestFullProtocolXProvisioneeFailDidCounterSign(t *testing.T) {
	for i, e := range testProtocolXWithBehavior(t, BadProvisioneeFailDidCounterSign) {
		if !eeq(e, ErrHandleDidCounterSign) {
			t.Fatalf("Bad error %d: %v", i, e)
		}
	}
}

// TestFullProtocolXProvisioneeSlowHello requires each side to end with a
// timeout- or EOF-style error (matched by suffix) when the provisionee's
// hello handler sleeps past the timeout.
func TestFullProtocolXProvisioneeSlowHello(t *testing.T) {
	for i, e := range testProtocolXWithBehavior(t, BadProvisioneeSlowHello) {
		switch {
		case errHasSuffix(e, ErrTimedOut), errHasSuffix(e, io.EOF), errHasSuffix(e, ErrHelloTimeout):
			// acceptable outcome
		default:
			t.Fatalf("Bad error %d: %v", i, e)
		}
	}
}

// TestFullProtocolXProvisioneeSlowHelloWithCancel requires each side to end
// with ErrCanceled or io.EOF when the provisioner's context is cancelled
// while the provisionee's hello handler is sleeping.
func TestFullProtocolXProvisioneeSlowHelloWithCancel(t *testing.T) {
	for i, e := range testProtocolXWithBehavior(t, BadProvisioneeSlowHello|BadProvisioneeCancel) {
		if eeq(e, ErrCanceled) || eeq(e, io.EOF) {
			continue
		}
		t.Fatalf("Bad error %d: %v", i, e)
	}
}
284
// TestFullProtocolY runs a successful exchange in which the provisionee
// receives the provisioner's secret s1 over its SecretChannel. Nothing is
// ever written to the provisioner's SecretChannel. Both sides must finish
// without error.
func TestFullProtocolY(t *testing.T) {
	timeout := time.Duration(60) * time.Second
	router := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, timeout)

	// Build every input here on the test goroutine. genUID (reached via
	// newMockProvisioner) and genKeybase1DeviceID fail with t.Fatalf, and the
	// testing package only permits FailNow from the goroutine running the
	// test, not from the worker goroutines started below.
	s1 := genSecret(t)
	provisioneeSecret := genSecret(t)
	provisionerDeviceID := genKeybase1DeviceID(t)
	provisioneeDeviceID := genKeybase1DeviceID(t)
	mockProv := newMockProvisioner(t)
	mockProvee := newMockProvisionee(t, GoodProvisionee)

	// Buffered beyond the two senders so neither worker blocks on its send.
	ch := make(chan error, 3)

	secretCh := make(chan Secret)
	testLogCtx, cleanup := newTestLogCtx(t)
	defer cleanup()

	// Run the provisioner.
	go func() {
		ch <- RunProvisioner(ProvisionerArg{
			KexBaseArg: KexBaseArg{
				Ctx:           context.TODO(),
				LogCtx:        testLogCtx,
				Mr:            router,
				Secret:        s1,
				DeviceID:      provisionerDeviceID,
				SecretChannel: make(chan Secret),
				Timeout:       timeout,
			},
			Provisioner: mockProv,
		})
	}()

	// Run the provisionee.
	go func() {
		ch <- RunProvisionee(ProvisioneeArg{
			KexBaseArg: KexBaseArg{
				Ctx:           context.TODO(),
				LogCtx:        testLogCtx,
				Mr:            router,
				Secret:        provisioneeSecret,
				DeviceID:      provisioneeDeviceID,
				SecretChannel: secretCh,
				Timeout:       timeout,
			},
			Provisionee: mockProvee,
		})
	}()

	secretCh <- s1

	for i := 0; i < 2; i++ {
		e, ok := <-ch
		if !ok {
			t.Fatalf("got unexpected channel close (try %d)", i)
		}
		if e != nil {
			t.Fatalf("Unexpected error (receive %d): %v", i, e)
		}
	}
}
343