1// Copyright 2015 Keybase, Inc. All rights reserved. Use of
2// this source code is governed by the included BSD license.
3
4package kex2
5
6import (
7	"crypto/hmac"
8	"crypto/rand"
9	"crypto/sha256"
10	"errors"
11	"fmt"
12	"io"
13	"net"
14	"sync"
15	"time"
16
17	"github.com/keybase/go-codec/codec"
18	"golang.org/x/crypto/nacl/secretbox"
19	"golang.org/x/net/context"
20)
21
// SecretLen is the number of bytes in a kex2 shared secret.
const SecretLen = 32

type (
	// DeviceID is the 16-byte identifier held by each side of a key
	// exchange. It is used primarily to tell the sender from the receiver.
	DeviceID [16]byte

	// SessionID is the 32-byte session identifier derived from the shared
	// session secret. The server uses it to route messages.
	SessionID [32]byte

	// Secret is the SecretLen-byte shared secret.
	Secret [SecretLen]byte

	// Seqno is the per-sender message counter; it increments on every
	// message a Kex sender sends.
	Seqno uint32
)
38
39// Eq returns true if the two device IDs are equal
40func (d DeviceID) Eq(d2 DeviceID) bool {
41	return hmac.Equal(d[:], d2[:])
42}
43
44// Eq returns true if the two session IDs are equal
45func (s SessionID) Eq(s2 SessionID) bool {
46	return hmac.Equal(s[:], s2[:])
47}
48
49// MessageRouter is a stateful message router that will be implemented by
50// JSON/REST calls to the Keybase API server.
51type MessageRouter interface {
52
53	// Post a message. Message will always be non-nil and non-empty.
54	// Even for an EOF, the empty buffer is encrypted via SecretBox,
55	// so the buffer posted to the server will have data.
56	Post(I SessionID, sender DeviceID, seqno Seqno, msg []byte) error
57
58	// Get messages on the channel.  Only poll for `poll` milliseconds. If the timeout
59	// elapses without any data ready, then just return an empty result, with nil error.
60	// Several messages can be returned at once, which should be processed in serial.
61	// They are guaranteed to be in order; otherwise, there was an issue.
62	// Get() should only return a non-nil error if there was an HTTPS or TCP-level error.
63	// Application-level errors like EOF or no data ready are handled by modulating
64	// the `msgs` result.
65	Get(I SessionID, receiver DeviceID, seqno Seqno, poll time.Duration) (msg [][]byte, err error)
66}
67
68// Conn is a struct that obeys the net.Conn interface. It establishes a session abstraction
69// over a message channel bounced off the Keybase API server, applying the appropriate
70// e2e encryption/MAC'ing.
71type Conn struct {
72	router    MessageRouter
73	secret    Secret
74	sessionID SessionID
75	deviceID  DeviceID
76
77	// Protects the read path. There should only be one reader outstanding at once.
78	readMutex    sync.Mutex
79	readSeqno    Seqno
80	readDeadline time.Time
81	readTimeout  time.Duration
82	bufferedMsgs [][]byte
83
84	// Protects the write path. There should only be one writer outstanding at once.
85	writeMutex sync.Mutex
86	writeSeqno Seqno
87
88	// Protects the pollLoopRunning mutex. We expose this mainly for testing purposes
89	pollLoopRunningMutex sync.Mutex
90	pollLoopRunning      bool
91
92	// Protects the setting of error states. Only one thread should be setting or
93	// accessing these errors at a time.
94	errMutex sync.Mutex
95	readErr  error
96	writeErr error
97	closed   bool
98
99	ctx  context.Context
100	lctx LogContext
101}
102
103const sessionIDText = "Kex v2 Session ID"
104
105// NewConn establishes a Kex session based on the given secret. Will work for
106// both ends of the connection, regardless of which order the two started
107// their connection. Will communicate with the other end via the given message router.
108// You can specify an optional timeout to cancel any reads longer than that timeout.
109func NewConn(ctx context.Context, lctx LogContext, r MessageRouter, s Secret, d DeviceID, readTimeout time.Duration) (con net.Conn, err error) {
110	mac := hmac.New(sha256.New, s[:])
111	_, err = mac.Write([]byte(sessionIDText))
112	if err != nil {
113		return nil, err
114	}
115	tmp := mac.Sum(nil)
116	var sessionID SessionID
117	copy(sessionID[:], tmp)
118	ret := &Conn{
119		router:      r,
120		secret:      s,
121		sessionID:   sessionID,
122		deviceID:    d,
123		readSeqno:   0,
124		readTimeout: readTimeout,
125		writeSeqno:  0,
126		ctx:         ctx,
127		lctx:        lctx,
128	}
129	return ret, nil
130}
131
// timedoutError is the net.Error used for operations that timed out; for
// instance, if no read data was available before the deadline.
type timedoutError struct{}

// Error returns the string representation of this error.
func (t timedoutError) Error() string { return "operation timed out" }

// Temporary reports that the operation may be retried.
func (t timedoutError) Temporary() bool { return true }

// Timeout reports that this error is a timeout.
func (t timedoutError) Timeout() bool { return true }

// ErrTimedOut is the singleton error we use when an operation times out.
var ErrTimedOut net.Error = timedoutError{}
147
var (
	// ErrUnimplemented indicates the given method isn't implemented.
	ErrUnimplemented = errors.New("unimplemented")

	// ErrBadMetadata indicates that the metadata outside the encrypted message
	// didn't match what was inside.
	ErrBadMetadata = errors.New("bad metadata")

	// ErrDecryption indicates that a ciphertext failed to decrypt or MAC
	// properly.
	ErrDecryption = errors.New("decryption failed")

	// ErrNotEnoughRandomness indicates that encryption failed because too few
	// random bytes were available for the nonce.
	ErrNotEnoughRandomness = errors.New("not enough random data")

	// ErrWrongSession indicates that the given session didn't match the
	// client's expectations.
	ErrWrongSession = errors.New("got message for wrong Session ID")

	// ErrSelfRecieve indicates that the client received a message sent by
	// itself, which should never happen. (The identifier keeps its historical
	// misspelling for API compatibility.)
	ErrSelfRecieve = errors.New("got message back that we sent")

	// ErrAgain indicates that no data was available to read, but the
	// reader was in non-blocking mode, so try again later.
	ErrAgain = errors.New("no data were ready to read")

	// ErrBadSecret indicates that the secret received was invalid.
	ErrBadSecret = errors.New("bad secret")

	// ErrHelloTimeout indicates that the Hello() part of the
	// protocol timed out, most likely due to an incorrect
	// secret phrase from the user.
	ErrHelloTimeout = errors.New("hello timeout")
)
181
182// ErrBadPacketSequence indicates that packets arrived out of order from the
183// server (which they shouldn't).
184type ErrBadPacketSequence struct {
185	SessionID     SessionID
186	SenderID      DeviceID
187	ReceivedSeqno Seqno
188	PrevSeqno     Seqno
189}
190
191func (e ErrBadPacketSequence) Error() string {
192	return fmt.Sprintf("Unexpected out-of-order packet arrival {SessionID: %v, SenderID: %v, ReceivedSeqno: %d, PrevSeqno: %d})",
193		e.SessionID, e.SenderID, e.ReceivedSeqno, e.PrevSeqno)
194}
195
196func (c *Conn) setReadError(e error) error {
197	c.errMutex.Lock()
198	c.readErr = e
199	c.errMutex.Unlock()
200	return e
201}
202
203func (c *Conn) setWriteError(e error) error {
204	c.errMutex.Lock()
205	c.writeErr = e
206	c.errMutex.Unlock()
207	return e
208}
209
210func (c *Conn) getErrorForWrite() error {
211	var err error
212	c.errMutex.Lock()
213	if c.readErr != nil && c.readErr != io.EOF {
214		err = c.readErr
215	} else if c.writeErr != nil {
216		err = c.writeErr
217	}
218	c.errMutex.Unlock()
219	return err
220}
221
222func (c *Conn) setClosed() {
223	c.errMutex.Lock()
224	c.closed = true
225	c.errMutex.Unlock()
226}
227
228func (c *Conn) getClosed() bool {
229	c.errMutex.Lock()
230	ret := c.closed
231	c.errMutex.Unlock()
232	return ret
233}
234
235func (c *Conn) getErrorForRead() error {
236	var err error
237	c.errMutex.Lock()
238	if c.readErr != nil {
239		err = c.readErr
240	} else if c.writeErr != nil && c.writeErr != io.EOF {
241		err = c.writeErr
242	}
243	c.errMutex.Unlock()
244	return err
245}
246
247func (c *Conn) setPollLoopRunning(b bool) {
248	c.pollLoopRunningMutex.Lock()
249	c.pollLoopRunning = b
250	c.pollLoopRunningMutex.Unlock()
251}
252
// outerMsg is the envelope that is msgpack-encoded and posted through the
// MessageRouter. Payload holds the secretbox ciphertext (sealed with Nonce)
// of an encoded innerMsg. The unencrypted SenderID/SessionID/Seqno must match
// the copies inside, or decryption reports ErrBadMetadata.
// The `toarray` codec option encodes the struct as an array, so field order is
// part of the wire format; do not reorder, insert or remove fields.
type outerMsg struct {
	_struct   bool      `codec:",toarray"` //nolint
	SenderID  DeviceID  `codec:"senderID"`
	SessionID SessionID `codec:"sessionID"`
	Seqno     Seqno     `codec:"seqno"`
	Nonce     [24]byte  `codec:"nonce"`
	Payload   []byte    `codec:"payload"`
}
261
// innerMsg is the authenticated plaintext that is msgpack-encoded and then
// sealed into outerMsg.Payload. An empty Payload signals EOF to the reader.
// The `toarray` codec option encodes the struct as an array, so field order is
// part of the wire format; do not reorder, insert or remove fields.
type innerMsg struct {
	_struct   bool      `codec:",toarray"` //nolint
	SenderID  DeviceID  `codec:"senderID"`
	SessionID SessionID `codec:"sessionID"`
	Seqno     Seqno     `codec:"seqno"`
	Payload   []byte    `codec:"payload"`
}
269
270func (c *Conn) decryptIncomingMessage(msg []byte) (int, error) {
271	var err error
272	mh := codec.MsgpackHandle{WriteExt: true}
273	dec := codec.NewDecoderBytes(msg, &mh)
274	var om outerMsg
275	err = dec.Decode(&om)
276	if err != nil {
277		c.lctx.Debug("Conn#decryptIncomingMessage: decoding failure: %s", err.Error())
278		return 0, err
279	}
280	var plaintext []byte
281	var ok bool
282	plaintext, ok = secretbox.Open(plaintext, om.Payload, &om.Nonce, (*[32]byte)(&c.secret))
283	if !ok {
284		return 0, ErrDecryption
285	}
286	dec = codec.NewDecoderBytes(plaintext, &mh)
287	var im innerMsg
288	err = dec.Decode(&im)
289	if err != nil {
290		return 0, err
291	}
292	if !om.SenderID.Eq(im.SenderID) || !om.SessionID.Eq(im.SessionID) || om.Seqno != im.Seqno {
293		return 0, ErrBadMetadata
294	}
295	if !im.SessionID.Eq(c.sessionID) {
296		return 0, ErrWrongSession
297	}
298	if im.SenderID.Eq(c.deviceID) {
299		return 0, ErrSelfRecieve
300	}
301
302	if im.Seqno != c.readSeqno+1 {
303		return 0, ErrBadPacketSequence{im.SessionID, im.SenderID, im.Seqno, c.readSeqno}
304	}
305	c.readSeqno = im.Seqno
306
307	c.bufferedMsgs = append(c.bufferedMsgs, im.Payload)
308	return len(im.Payload), nil
309}
310
311func (c *Conn) decryptIncomingMessages(msgs [][]byte) (int, error) {
312	var ret int
313	for _, msg := range msgs {
314		n, e := c.decryptIncomingMessage(msg)
315		if e != nil {
316			return ret, e
317		}
318		ret += n
319	}
320	return ret, nil
321}
322
323func (c *Conn) readBufferedMsgsIntoBytes(out []byte) (int, error) {
324	p := 0
325
326	// If no buffered messages, then return that we didn't pull any
327	// new data from the server.
328	if len(c.bufferedMsgs) == 0 {
329		return 0, nil
330	}
331
332	// Any empty buffer signals an EOF condition
333	if len(c.bufferedMsgs[0]) == 0 {
334		c.lctx.Debug("conn#readBufferedMsgsIntoBytes: empty buffer signaling EOF condition")
335		return 0, io.EOF
336	}
337
338	for p < len(out) {
339		rem := len(out) - p
340		if len(c.bufferedMsgs) > 0 {
341			front := c.bufferedMsgs[0]
342			n := len(front)
343
344			// An empty buffer signifies that the other side wanted
345			// and EOF condition. However, we shouldn't return an EOF
346			// if we've read anything, this time through.
347			if n == 0 {
348				var err error
349				if p == 0 {
350					c.lctx.Debug("conn#readBufferedMsgsIntoBytes: empty buffer signaling EOF condition (after consume loop)")
351					err = io.EOF
352				}
353				return p, err
354			}
355
356			if rem < n {
357				n = rem
358				copy(out[p:(p+n)], front[0:n])
359				front = front[n:]
360				if len(front) == 0 {
361					// Be careful not to recycle an empty buffer into the
362					// list of buffered messages, since that has special
363					// significance (see above).
364					c.bufferedMsgs = c.bufferedMsgs[1:]
365				} else {
366					c.bufferedMsgs[0] = front
367				}
368			} else {
369				copy(out[p:(p+n)], front)
370				c.bufferedMsgs = c.bufferedMsgs[1:]
371			}
372
373			p += n
374		} else {
375			break
376		}
377	}
378	return p, nil
379}
380
381func (c *Conn) pollLoop(poll time.Duration) (msgs [][]byte, err error) {
382
383	var totalWaitTime time.Duration
384
385	c.setPollLoopRunning(true)
386	defer c.setPollLoopRunning(false)
387
388	start := time.Now()
389	for {
390		newPoll := poll - totalWaitTime
391		msgs, err = c.router.Get(c.sessionID, c.deviceID, c.readSeqno+1, newPoll)
392		totalWaitTime = time.Since(start)
393		if err != nil || len(msgs) > 0 || totalWaitTime >= poll || c.getClosed() {
394			return
395		}
396
397		select {
398		case <-c.ctx.Done():
399			return nil, ErrCanceled
400		default:
401		}
402	}
403}
404
405// Read data from the connection, returning plaintext data if all
406// cryptographic checks passed. Obeys the `net.Conn` interface.
407// Returns the number of bytes read into the output buffer.
408func (c *Conn) Read(out []byte) (n int, err error) {
409
410	c.readMutex.Lock()
411	defer c.readMutex.Unlock()
412
413	// The first error kills the whole stream
414	if err = c.getErrorForRead(); err != nil {
415		return 0, err
416	}
417	// First see if there's anything buffered, and read that
418	// out now.
419	if n, err = c.readBufferedMsgsIntoBytes(out); err != nil {
420		return 0, c.setReadError(err)
421	}
422	if n > 0 {
423		return n, nil
424	}
425
426	var poll time.Duration
427	if !c.readDeadline.IsZero() {
428		poll = time.Until(c.readDeadline)
429		if poll.Nanoseconds() < 0 {
430			return 0, c.setReadError(ErrTimedOut)
431		}
432	} else {
433		poll = c.readTimeout
434	}
435
436	var msgs [][]byte
437	msgs, err = c.pollLoop(poll)
438
439	if err != nil {
440		return 0, c.setReadError(err)
441	}
442	if _, err = c.decryptIncomingMessages(msgs); err != nil {
443		return 0, c.setReadError(err)
444	}
445	if n, err = c.readBufferedMsgsIntoBytes(out); err != nil {
446		return 0, c.setReadError(err)
447	}
448
449	if n == 0 {
450		switch {
451		case c.getClosed():
452			c.lctx.Debug("conn#Read: EOF since connection was closed")
453			err = io.EOF
454		case poll > 0:
455			err = ErrTimedOut
456		default:
457			err = ErrAgain
458		}
459	}
460
461	return n, err
462}
463
464func (c *Conn) encryptOutgoingMessage(seqno Seqno, buf []byte) (ret []byte, err error) {
465	var nonce [24]byte
466	var n int
467
468	if n, err = rand.Read(nonce[:]); err != nil {
469		return nil, err
470	} else if n != 24 {
471		return nil, ErrNotEnoughRandomness
472	}
473	im := innerMsg{
474		SenderID:  c.deviceID,
475		SessionID: c.sessionID,
476		Seqno:     seqno,
477		Payload:   buf,
478	}
479	mh := codec.MsgpackHandle{WriteExt: true}
480	var imPacked []byte
481	enc := codec.NewEncoderBytes(&imPacked, &mh)
482	if err = enc.Encode(im); err != nil {
483		return nil, err
484	}
485	ciphertext := secretbox.Seal(nil, imPacked, &nonce, (*[32]byte)(&c.secret))
486
487	om := outerMsg{
488		SenderID:  c.deviceID,
489		SessionID: c.sessionID,
490		Seqno:     seqno,
491		Nonce:     nonce,
492		Payload:   ciphertext,
493	}
494	enc = codec.NewEncoderBytes(&ret, &mh)
495	if err = enc.Encode(om); err != nil {
496		return nil, err
497	}
498	return ret, nil
499}
500
501func (c *Conn) nextWriteSeqno() Seqno {
502	c.writeSeqno++
503	return c.writeSeqno
504}
505
506// Write data to the connection, encrypting and MAC'ing along the way.
507// Obeys the `net.Conn` interface
508func (c *Conn) Write(buf []byte) (n int, err error) {
509
510	c.writeMutex.Lock()
511	defer c.writeMutex.Unlock()
512
513	// Our protocol specifies that writing an empty buffer means "close"
514	// the connection.  We don't want callers of `Write` to do this by
515	// accident, we want them to call `Close()` explicitly. So short-circuit
516	// the write operation here for empty buffers.
517	if len(buf) == 0 {
518		return 0, nil
519	}
520
521	return c.writeWithLock(buf)
522}
523
524func (c *Conn) writeWithLock(buf []byte) (n int, err error) {
525
526	var ctext []byte
527
528	// The first error kills the whole stream
529	if err = c.getErrorForWrite(); err != nil {
530		return 0, err
531	}
532	seqno := c.nextWriteSeqno()
533
534	ctext, err = c.encryptOutgoingMessage(seqno, buf)
535	if err != nil {
536		return 0, c.setWriteError(err)
537	}
538
539	if err = c.router.Post(c.sessionID, c.deviceID, seqno, ctext); err != nil {
540		return 0, c.setWriteError(err)
541	}
542
543	return len(ctext), nil
544}
545
546// Close the connection to the server, sending an empty buffer via POST
547// through the `MessageRouter`. Fulfills the `net.Conn` interface
548func (c *Conn) Close() error {
549
550	c.writeMutex.Lock()
551	defer c.writeMutex.Unlock()
552
553	c.lctx.Debug("Conn#Close: all subsequent writes are EOFs")
554
555	// set closed so that the read loop will bail out above
556	c.setClosed()
557
558	// Write an empty buffer to signal EOF
559	if _, err := c.writeWithLock([]byte{}); err != nil {
560		return err
561	}
562
563	// All subsequent writes should fail.
564	_ = c.setWriteError(io.EOF)
565
566	return nil
567}
568
569// LocalAddr returns the local network address, fulfilling the `net.Conn interface`
570func (c *Conn) LocalAddr() (addr net.Addr) {
571	return
572}
573
574// RemoteAddr returns the remote network address, fulfilling the `net.Conn interface`
575func (c *Conn) RemoteAddr() (addr net.Addr) {
576	return
577}
578
579// SetDeadline sets the read and write deadlines associated
580// with the connection. It is equivalent to calling both
581// SetReadDeadline and SetWriteDeadline.
582//
583// A deadline is an absolute time after which I/O operations
584// fail with a timeout (see type Error) instead of
585// blocking. The deadline applies to all future I/O, not just
586// the immediately following call to Read or Write.
587//
588// An idle timeout can be implemented by repeatedly extending
589// the deadline after successful Read or Write calls.
590//
591// A zero value for t means I/O operations will not time out.
592func (c *Conn) SetDeadline(t time.Time) error {
593	return c.SetReadDeadline(t)
594}
595
596// SetReadDeadline sets the deadline for future Read calls.
597// A zero value for t means Read will not time out.
598func (c *Conn) SetReadDeadline(t time.Time) error {
599	c.readMutex.Lock()
600	c.readDeadline = t
601	c.readMutex.Unlock()
602	return nil
603}
604
605// SetWriteDeadline sets the deadline for future Write calls.
606// Even if write times out, it may return n > 0, indicating that
607// some of the data was successfully written.
608// A zero value for t means Write will not time out.
609// We're not implementing this feature for now, so make it an error
610// if we try to do so.
611func (c *Conn) SetWriteDeadline(t time.Time) error {
612	return ErrUnimplemented
613}
614