1package mint
2
3import (
4	"crypto/cipher"
5	"fmt"
6	"io"
7	"sync"
8)
9
const (
	// sequenceNumberLen is the size in bytes of the explicit sequence number
	// field that DTLS carries in every record header (epoch + sequence).
	sequenceNumberLen = 8

	// recordHeaderLenTLS is the TLS record header size:
	// type(1) | version(2) | length(2).
	recordHeaderLenTLS = 5

	// recordHeaderLenDTLS is the DTLS record header size: the TLS fields plus
	// the 8-byte epoch/sequence field placed between version and length.
	recordHeaderLenDTLS = recordHeaderLenTLS + sequenceNumberLen

	// maxFragmentLen is the largest plaintext fragment (2^14 bytes) a record
	// may carry; protected records may exceed it by up to 256 bytes.
	maxFragmentLen = 1 << 14
)
16
// DecryptError reports a failure to remove record protection from a received
// record, for example a record shorter than the AEAD overhead or one whose
// AEAD authentication check failed.
type DecryptError string

// Error implements the error interface, returning the message unchanged.
func (e DecryptError) Error() string {
	msg := string(e)
	return msg
}
22
// direction says whether a RecordLayer protects and sends records (write) or
// receives and unprotects them (read).
type direction uint8

const (
	directionWrite direction = iota + 1 // outgoing records are encrypted and written
	directionRead                       // incoming records are read and decrypted
)
29
// TLSPlaintext is a single record as handled by the record layer. On the wire
// a record has this shape:
//
// struct {
//     ContentType type;
//     ProtocolVersion record_version [0301 for CH, 0303 for others]
//     uint16 length;
//     opaque fragment[TLSPlaintext.length];
// } TLSPlaintext;
//
// The same Go type carries both unprotected and protected records: encrypt
// returns a TLSPlaintext whose fragment is ciphertext and whose contentType is
// RecordTypeApplicationData, and decrypt returns one holding the recovered
// inner content type and plaintext.
type TLSPlaintext struct {
	// Omitted: record_version (static)
	// Omitted: length         (computed from fragment)
	contentType RecordType
	epoch       Epoch  // epoch the record was read under; set by nextRecord, not consulted when writing
	seq         uint64 // sequence number used when the record was decrypted; set by decrypt
	fragment    []byte
}
44
// cipherState is the record-protection state for one epoch. A state whose
// cipher is nil (see newCipherStateNull) sends and accepts records
// unprotected.
//
// Field order matters: newCipherStateNull and newCipherStateAead build this
// struct with positional literals.
type cipherState struct {
	epoch    Epoch       // DTLS epoch; combineSeq folds it into the top 16 bits of the sequence number
	ivLength int         // len(iv) as given to newCipherStateAead; not read by the code in this file section
	seq      uint64      // next sequence number for this state; advanced by incrementSequenceNumber after each record
	iv       []byte      // static IV; computeNonce XORs the sequence number into a copy of it
	cipher   cipher.AEAD // AEAD used to seal/open records; nil while records are unprotected
}
52
// RecordLayer reads or writes TLS/DTLS records over an underlying connection
// for a single direction. Records are unprotected until keys are installed
// with Rekey, after which they are sealed/opened with the installed AEAD.
type RecordLayer struct {
	// NOTE(review): none of the methods in this file section lock this mutex;
	// presumably callers hold it around record-layer calls — confirm.
	sync.Mutex
	label        string        // prefix used in this layer's log messages
	direction    direction     // directionWrite or directionRead; encrypt/decrypt assert on it
	version      uint16        // The current version number; written into outgoing record headers
	conn         io.ReadWriter // The underlying connection
	frame        *frameReader  // The buffered frame reader
	nextData     []byte        // The next record to send (NOTE(review): not used in this file section)
	cachedRecord *TLSPlaintext // Last record read, cached to enable "peek"
	cachedError  error         // Error on the last record read

	// cipher is the current protection state: used for every write and for
	// reads in the current epoch.
	cipher *cipherState
	// readCiphers maps each known epoch to its read-side state. It is only
	// created for DTLS and only extended by Rekey in the read direction.
	readCiphers map[Epoch]*cipherState

	datagram bool // true for DTLS framing (13-byte header, explicit epoch/sequence)
}
69
70type recordLayerFrameDetails struct {
71	datagram bool
72}
73
74func (d recordLayerFrameDetails) headerLen() int {
75	if d.datagram {
76		return recordHeaderLenDTLS
77	}
78	return recordHeaderLenTLS
79}
80
81func (d recordLayerFrameDetails) defaultReadLen() int {
82	return d.headerLen() + maxFragmentLen
83}
84
85func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) {
86	return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil
87}
88
89func newCipherStateNull() *cipherState {
90	return &cipherState{EpochClear, 0, 0, nil, nil}
91}
92
93func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) {
94	cipher, err := factory(key)
95	if err != nil {
96		return nil, err
97	}
98
99	return &cipherState{epoch, len(iv), 0, iv, cipher}, nil
100}
101
102func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer {
103	r := RecordLayer{}
104	r.label = ""
105	r.direction = dir
106	r.conn = conn
107	r.frame = newFrameReader(recordLayerFrameDetails{false})
108	r.cipher = newCipherStateNull()
109	r.version = tls10Version
110	return &r
111}
112
113func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer {
114	r := RecordLayer{}
115	r.label = ""
116	r.direction = dir
117	r.conn = conn
118	r.frame = newFrameReader(recordLayerFrameDetails{true})
119	r.cipher = newCipherStateNull()
120	r.readCiphers = make(map[Epoch]*cipherState, 0)
121	r.readCiphers[0] = r.cipher
122	r.datagram = true
123	return &r
124}
125
126func (r *RecordLayer) SetVersion(v uint16) {
127	r.version = v
128}
129
130func (r *RecordLayer) ResetClear(seq uint64) {
131	r.cipher = newCipherStateNull()
132	r.cipher.seq = seq
133}
134
135func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error {
136	cipher, err := newCipherStateAead(epoch, factory, key, iv)
137	if err != nil {
138		return err
139	}
140	r.cipher = cipher
141	if r.datagram && r.direction == directionRead {
142		r.readCiphers[epoch] = cipher
143	}
144	return nil
145}
146
147// TODO(ekr@rtfm.com): This is never used, which is a bug.
148func (r *RecordLayer) DiscardReadKey(epoch Epoch) {
149	if !r.datagram {
150		return
151	}
152
153	_, ok := r.readCiphers[epoch]
154	assert(ok)
155	delete(r.readCiphers, epoch)
156}
157
158func (c *cipherState) combineSeq(datagram bool) uint64 {
159	seq := c.seq
160	if datagram {
161		seq |= uint64(c.epoch) << 48
162	}
163	return seq
164}
165
166func (c *cipherState) computeNonce(seq uint64) []byte {
167	nonce := make([]byte, len(c.iv))
168	copy(nonce, c.iv)
169
170	s := seq
171
172	offset := len(c.iv)
173	for i := 0; i < 8; i++ {
174		nonce[(offset-i)-1] ^= byte(s & 0xff)
175		s >>= 8
176	}
177	logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce)
178
179	return nonce
180}
181
182func (c *cipherState) incrementSequenceNumber() {
183	if c.seq >= (1<<48 - 1) {
184		// Not allowed to let sequence number wrap.
185		// Instead, must renegotiate before it does.
186		// Not likely enough to bother. This is the
187		// DTLS limit.
188		panic("TLS: sequence number wraparound")
189	}
190	c.seq++
191}
192
193func (c *cipherState) overhead() int {
194	if c.cipher == nil {
195		return 0
196	}
197	return c.cipher.Overhead()
198}
199
200func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext {
201	assert(r.direction == directionWrite)
202	logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq)
203	// Expand the fragment to hold contentType, padding, and overhead
204	originalLen := len(pt.fragment)
205	plaintextLen := originalLen + 1 + padLen
206	ciphertextLen := plaintextLen + cipher.overhead()
207
208	// Assemble the revised plaintext
209	out := &TLSPlaintext{
210
211		contentType: RecordTypeApplicationData,
212		fragment:    make([]byte, ciphertextLen),
213	}
214	copy(out.fragment, pt.fragment)
215	out.fragment[originalLen] = byte(pt.contentType)
216	for i := 1; i <= padLen; i++ {
217		out.fragment[originalLen+i] = 0
218	}
219
220	// Encrypt the fragment
221	payload := out.fragment[:plaintextLen]
222	cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil)
223	return out
224}
225
226func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) {
227	assert(r.direction == directionRead)
228	logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq)
229	if len(pt.fragment) < r.cipher.overhead() {
230		msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead())
231		return nil, 0, DecryptError(msg)
232	}
233
234	decryptLen := len(pt.fragment) - r.cipher.overhead()
235	out := &TLSPlaintext{
236		contentType: pt.contentType,
237		fragment:    make([]byte, decryptLen),
238	}
239
240	// Decrypt
241	_, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil)
242	if err != nil {
243		logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt)
244		return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed")
245	}
246
247	// Find the padding boundary
248	padLen := 0
249	for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ {
250	}
251
252	// Transfer the content type
253	newLen := decryptLen - padLen - 1
254	out.contentType = RecordType(out.fragment[newLen])
255
256	// Truncate the message to remove contentType, padding, overhead
257	out.fragment = out.fragment[:newLen]
258	out.seq = seq
259	return out, padLen, nil
260}
261
262func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) {
263	var pt *TLSPlaintext
264	var err error
265
266	for {
267		pt, err = r.nextRecord(false)
268		if err == nil {
269			break
270		}
271		if !block || err != AlertWouldBlock {
272			return 0, err
273		}
274	}
275	return pt.contentType, nil
276}
277
278func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) {
279	pt, err := r.nextRecord(false)
280
281	// Consume the cached record if there was one
282	r.cachedRecord = nil
283	r.cachedError = nil
284
285	return pt, err
286}
287
288func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) {
289	pt, err := r.nextRecord(true)
290
291	// Consume the cached record if there was one
292	r.cachedRecord = nil
293	r.cachedError = nil
294
295	return pt, err
296}
297
// nextRecord returns the next record without consuming it. A successfully
// read record is cached in r.cachedRecord and returned again by later calls
// until ReadRecord or readRecordAnyEpoch clears the cache.
//
// It reads from r.conn into the frame reader until a complete frame is
// available. It then validates the header: content type, version (only when
// allowWrongVersionNumber is false) and size. If the selected cipher state
// has an AEAD, it decrypts the fragment, and it finally advances that state's
// sequence number.
//
// For DTLS (r.datagram) the header's 64-bit sequence field carries the epoch
// in its top 16 bits:
//   - Records from an epoch with no entry in r.readCiphers are dropped and
//     AlertWouldBlock is returned.
//   - Records from a known but non-current epoch are dropped the same way
//     unless allowOldEpoch is true. In that case the local cipher switches to
//     that epoch's state.
//
// AlertWouldBlock is also returned when a read from r.conn yields 0 bytes.
func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) {
	cipher := r.cipher
	if r.cachedRecord != nil {
		logf(logTypeIO, "%s Returning cached record", r.label)
		return r.cachedRecord, r.cachedError
	}

	// Loop until one of three things happens:
	//
	// 1. We get a frame
	// 2. We try to read off the socket and get nothing, in which case
	//    return AlertWouldBlock
	// 3. We get an error.
	var err error
	err = AlertWouldBlock // primes the loop so it runs at least once
	var header, body []byte

	for err != nil {
		if r.frame.needed() > 0 {
			buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen)
			// This err shadows the outer one; it is only used for the early return.
			n, err := r.conn.Read(buf)
			if err != nil {
				logf(logTypeIO, "%s Error reading, %v", r.label, err)
				return nil, err
			}

			if n == 0 {
				return nil, AlertWouldBlock
			}

			logf(logTypeIO, "%s Read %v bytes", r.label, n)

			buf = buf[:n]
			r.frame.addChunk(buf)
		}

		header, body, err = r.frame.process()
		// Loop around on AlertWouldBlock to see if some
		// data is now available.
		if err != nil && err != AlertWouldBlock {
			return nil, err
		}
	}

	pt := &TLSPlaintext{}
	// Validate content type
	switch RecordType(header[0]) {
	default:
		return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0])
	case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck:
		pt.contentType = RecordType(header[0])
	}

	// Validate version
	if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) {
		return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2])
	}

	// Validate size < max. The length is the last two header bytes; protected
	// records may exceed the plaintext limit by up to 256 bytes.
	size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1])

	if size > maxFragmentLen+256 {
		return nil, fmt.Errorf("tls.record: Ciphertext size too big")
	}

	pt.fragment = make([]byte, size)
	copy(pt.fragment, body)

	// TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data.

	// Attempt to decrypt fragment
	seq := cipher.seq
	if r.datagram {
		// TODO(ekr@rtfm.com): Handle duplicates.
		// DTLS carries an explicit sequence number whose top 16 bits are the epoch.
		seq, _ = decodeUint(header[3:11], 8)
		epoch := Epoch(seq >> 48)

		// Look up the cipher suite from the epoch
		c, ok := r.readCiphers[epoch]
		if !ok {
			logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch)
			return nil, AlertWouldBlock
		}

		if epoch != cipher.epoch {
			logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch,
				cipher.epoch, allowOldEpoch)
			if !allowOldEpoch {
				return nil, AlertWouldBlock
			}
			// NOTE(review): this local cipher is not passed to r.decrypt; r.decrypt
			// must pick the matching epoch's keys itself — confirm it does.
			cipher = c
		}
	}

	if cipher.cipher != nil {
		logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment)
		pt, _, err = r.decrypt(pt, seq)
		if err != nil {
			logf(logTypeIO, "%s Decryption failed", r.label)
			return nil, err
		}
	}
	pt.epoch = cipher.epoch

	// Check that plaintext length is not too long
	if len(pt.fragment) > maxFragmentLen {
		return nil, fmt.Errorf("tls.record: Plaintext size too big")
	}

	logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment)

	r.cachedRecord = pt
	cipher.incrementSequenceNumber()
	return pt, nil
}
413
414func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error {
415	return r.writeRecordWithPadding(pt, r.cipher, 0)
416}
417
418func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error {
419	return r.writeRecordWithPadding(pt, r.cipher, padLen)
420}
421
422func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error {
423	seq := cipher.combineSeq(r.datagram)
424	if cipher.cipher != nil {
425		logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
426		pt = r.encrypt(cipher, seq, pt, padLen)
427	} else if padLen > 0 {
428		return fmt.Errorf("tls.record: Padding can only be done on encrypted records")
429	}
430
431	if len(pt.fragment) > maxFragmentLen {
432		return fmt.Errorf("tls.record: Record size too big")
433	}
434
435	length := len(pt.fragment)
436	var header []byte
437
438	if !r.datagram {
439		header = []byte{byte(pt.contentType),
440			byte(r.version >> 8), byte(r.version & 0xff),
441			byte(length >> 8), byte(length)}
442	} else {
443		header = make([]byte, 13)
444		version := dtlsConvertVersion(r.version)
445		copy(header, []byte{byte(pt.contentType),
446			byte(version >> 8), byte(version & 0xff),
447		})
448		encodeUint(seq, 8, header[3:])
449		encodeUint(uint64(length), 2, header[11:])
450	}
451	record := append(header, pt.fragment...)
452
453	logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment)
454
455	cipher.incrementSequenceNumber()
456	_, err := r.conn.Write(record)
457	return err
458}
459