1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// DTLS implementation.
6//
// NOTE: This is not even a remotely production-quality DTLS
8// implementation. It is the bare minimum necessary to be able to
9// achieve coverage on BoringSSL's implementation. Of note is that
10// this implementation assumes the underlying net.PacketConn is not
11// only reliable but also ordered. BoringSSL will be expected to deal
12// with simulated loss, but there is no point in forcing the test
13// driver to.
14
15package runner
16
17import (
18	"bytes"
19	"errors"
20	"fmt"
21	"io"
22	"math/rand"
23	"net"
24)
25
26func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
27	recordHeaderLen := dtlsRecordHeaderLen
28
29	if c.rawInput == nil {
30		c.rawInput = c.in.newBlock()
31	}
32	b := c.rawInput
33
34	// Read a new packet only if the current one is empty.
35	var newPacket bool
36	if len(b.data) == 0 {
37		// Pick some absurdly large buffer size.
38		b.resize(maxCiphertext + recordHeaderLen)
39		n, err := c.conn.Read(c.rawInput.data)
40		if err != nil {
41			return 0, nil, err
42		}
43		if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
44			return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
45		}
46		c.rawInput.resize(n)
47		newPacket = true
48	}
49
50	// Read out one record.
51	//
52	// A real DTLS implementation should be tolerant of errors,
53	// but this is test code. We should not be tolerant of our
54	// peer sending garbage.
55	if len(b.data) < recordHeaderLen {
56		return 0, nil, errors.New("dtls: failed to read record header")
57	}
58	typ := recordType(b.data[0])
59	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
60	// Alerts sent near version negotiation do not have a well-defined
61	// record-layer version prior to TLS 1.3. (In TLS 1.3, the record-layer
62	// version is irrelevant.)
63	if typ != recordTypeAlert {
64		if c.haveVers {
65			if vers != c.wireVersion {
66				c.sendAlert(alertProtocolVersion)
67				return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion))
68			}
69		} else {
70			// Pre-version-negotiation alerts may be sent with any version.
71			if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
72				c.sendAlert(alertProtocolVersion)
73				return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
74			}
75		}
76	}
77	epoch := b.data[3:5]
78	seq := b.data[5:11]
79	// For test purposes, require the sequence number be monotonically
80	// increasing, so c.in includes the minimum next sequence number. Gaps
81	// may occur if packets failed to be sent out. A real implementation
82	// would maintain a replay window and such.
83	if !bytes.Equal(epoch, c.in.seq[:2]) {
84		c.sendAlert(alertIllegalParameter)
85		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
86	}
87	if bytes.Compare(seq, c.in.seq[2:]) < 0 {
88		c.sendAlert(alertIllegalParameter)
89		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
90	}
91	copy(c.in.seq[2:], seq)
92	n := int(b.data[11])<<8 | int(b.data[12])
93	if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
94		c.sendAlert(alertRecordOverflow)
95		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
96	}
97
98	// Process message.
99	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
100	ok, off, _, alertValue := c.in.decrypt(b)
101	if !ok {
102		// A real DTLS implementation would silently ignore bad records,
103		// but we want to notice errors from the implementation under
104		// test.
105		return 0, nil, c.in.setErrorLocked(c.sendAlert(alertValue))
106	}
107	b.off = off
108
109	// TODO(nharper): Once DTLS 1.3 is defined, handle the extra
110	// parameter from decrypt.
111
112	// Require that ChangeCipherSpec always share a packet with either the
113	// previous or next handshake message.
114	if newPacket && typ == recordTypeChangeCipherSpec && c.rawInput == nil {
115		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: ChangeCipherSpec not packed together with Finished"))
116	}
117
118	return typ, b, nil
119}
120
121func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte {
122	fragment := make([]byte, 0, 12+fragLen)
123	fragment = append(fragment, header...)
124	fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
125	fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset))
126	fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen))
127	fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...)
128	return fragment
129}
130
// dtlsWriteRecord writes a record of type typ carrying data.
//
// Non-handshake records (e.g. ChangeCipherSpec, application data) are packed
// immediately with dtlsPackRecord:
//   - Any queued handshake fragments are packed first. The exception is a
//     ChangeCipherSpec with Bugs.ReorderChangeCipherSpec set, where the queued
//     fragments are packed after it.
//   - After packing, a ChangeCipherSpec switches the outgoing cipher state via
//     c.out.changeCipherSpec and is left in the pending packet.
//   - Any other non-handshake record causes the pending packet to be flushed.
//
// Handshake data is not written here. data is assumed to hold exactly one
// handshake message including its 4-byte header. The message is cut into DTLS
// fragments of at most Bugs.MaxHandshakeRecordLength body bytes (1024 by
// default). Each fragment is queued in c.pendingFragments for a later
// dtlsPackHandshake, and c.sendHandshakeSeq is then incremented. Several
// c.config.Bugs options add empty, duplicated, overlapping, complete or
// header-mismatched fragments to the queue.
//
// For non-handshake records n is the number of bytes of data written. For
// handshake data n is the number of message body bytes covered by the queued
// fragments, which excludes the 4-byte header.
func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
	// Only handshake messages are fragmented.
	if typ != recordTypeHandshake {
		reorder := typ == recordTypeChangeCipherSpec && c.config.Bugs.ReorderChangeCipherSpec

		// Flush pending handshake messages before encrypting a new record.
		if !reorder {
			err = c.dtlsPackHandshake()
			if err != nil {
				return
			}
		}

		if typ == recordTypeApplicationData && len(data) > 1 && c.config.Bugs.SplitAndPackAppData {
			// Send the data as two records. The second is packed with
			// mustPack set so both share one packet.
			_, err = c.dtlsPackRecord(typ, data[:len(data)/2], false)
			if err != nil {
				return
			}
			_, err = c.dtlsPackRecord(typ, data[len(data)/2:], true)
			if err != nil {
				return
			}
			n = len(data)
		} else {
			n, err = c.dtlsPackRecord(typ, data, false)
			if err != nil {
				return
			}
		}

		if reorder {
			// Pack the queued handshake fragments after the ChangeCipherSpec.
			err = c.dtlsPackHandshake()
			if err != nil {
				return
			}
		}

		if typ == recordTypeChangeCipherSpec {
			err = c.out.changeCipherSpec(c.config)
			if err != nil {
				// The error is type-asserted to an alert. A non-alert error
				// from changeCipherSpec would panic here.
				return n, c.sendAlertLocked(alertLevelError, err.(alert))
			}
		} else {
			// ChangeCipherSpec is part of the handshake and not
			// flushed until dtlsFlushPacket.
			err = c.dtlsFlushPacket()
			if err != nil {
				return
			}
		}
		return
	}

	// Before the outgoing cipher is set, optionally inject an unexpected
	// ChangeCipherSpec record ahead of this message's fragments.
	if c.out.cipher == nil && c.config.Bugs.StrayChangeCipherSpec {
		_, err = c.dtlsPackRecord(recordTypeChangeCipherSpec, []byte{1}, false)
		if err != nil {
			return
		}
	}

	// Maximum fragment body length.
	maxLen := c.config.Bugs.MaxHandshakeRecordLength
	if maxLen <= 0 {
		maxLen = 1024
	}

	// Handshake messages have to be modified to include fragment
	// offset and length and with the header replicated. Save the
	// TLS header here.
	//
	// TODO(davidben): This assumes that data contains exactly one
	// handshake message. This is incompatible with
	// FragmentAcrossChangeCipherSpec. (Which is unfortunate
	// because OpenSSL's DTLS implementation will probably accept
	// such fragmentation and could do with a fix + tests.)
	header := data[:4]
	data = data[4:]

	isFinished := header[0] == typeFinished

	if c.config.Bugs.SendEmptyFragments {
		// Zero-length fragments at the start and end of the message.
		c.pendingFragments = append(c.pendingFragments, c.makeFragment(header, data, 0, 0))
		c.pendingFragments = append(c.pendingFragments, c.makeFragment(header, data, len(data), 0))
	}

	// firstRun ensures an empty message still produces one fragment.
	firstRun := true
	fragOffset := 0
	for firstRun || fragOffset < len(data) {
		firstRun = false
		fragLen := len(data) - fragOffset
		if fragLen > maxLen {
			fragLen = maxLen
		}

		fragment := c.makeFragment(header, data, fragOffset, fragLen)
		// Corrupt the type or length byte of every fragment after the
		// first, so they disagree with the first fragment's header.
		if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 {
			fragment[0]++
		}
		if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 {
			fragment[3]++
		}

		// Buffer the fragment for later. They will be sent (and
		// reordered) on flush.
		c.pendingFragments = append(c.pendingFragments, fragment)
		if c.config.Bugs.ReorderHandshakeFragments {
			// Don't duplicate Finished to avoid the peer
			// interpreting it as a retransmit request.
			if !isFinished {
				c.pendingFragments = append(c.pendingFragments, fragment)
			}

			if fragLen > (maxLen+1)/2 {
				// Overlap each fragment by half.
				fragLen = (maxLen + 1) / 2
			}
		}
		// fragLen is now the stride to the next fragment, so n sums to the
		// message body length even when fragments overlap.
		fragOffset += fragLen
		n += fragLen
	}
	// Optionally also queue the whole message body as a single fragment.
	shouldSendTwice := c.config.Bugs.MixCompleteMessageWithFragments
	if isFinished {
		shouldSendTwice = c.config.Bugs.RetransmitFinished
	}
	if shouldSendTwice {
		fragment := c.makeFragment(header, data, 0, len(data))
		c.pendingFragments = append(c.pendingFragments, fragment)
	}

	// Increment the handshake sequence number for the next
	// handshake message.
	c.sendHandshakeSeq++
	return
}
264
265// dtlsPackHandshake packs the pending handshake flight into the pending
266// record. Callers should follow up with dtlsFlushPacket to write the packets.
267func (c *Conn) dtlsPackHandshake() error {
268	// This is a test-only DTLS implementation, so there is no need to
269	// retain |c.pendingFragments| for a future retransmit.
270	var fragments [][]byte
271	fragments, c.pendingFragments = c.pendingFragments, fragments
272
273	if c.config.Bugs.ReorderHandshakeFragments {
274		perm := rand.New(rand.NewSource(0)).Perm(len(fragments))
275		tmp := make([][]byte, len(fragments))
276		for i := range tmp {
277			tmp[i] = fragments[perm[i]]
278		}
279		fragments = tmp
280	} else if c.config.Bugs.ReverseHandshakeFragments {
281		tmp := make([][]byte, len(fragments))
282		for i := range tmp {
283			tmp[i] = fragments[len(fragments)-i-1]
284		}
285		fragments = tmp
286	}
287
288	maxRecordLen := c.config.Bugs.PackHandshakeFragments
289
290	// Pack handshake fragments into records.
291	var records [][]byte
292	for _, fragment := range fragments {
293		if n := c.config.Bugs.SplitFragments; n > 0 {
294			if len(fragment) > n {
295				records = append(records, fragment[:n])
296				records = append(records, fragment[n:])
297			} else {
298				records = append(records, fragment)
299			}
300		} else if i := len(records) - 1; len(records) > 0 && len(records[i])+len(fragment) <= maxRecordLen {
301			records[i] = append(records[i], fragment...)
302		} else {
303			// The fragment will be appended to, so copy it.
304			records = append(records, append([]byte{}, fragment...))
305		}
306	}
307
308	// Send the records.
309	for _, record := range records {
310		_, err := c.dtlsPackRecord(recordTypeHandshake, record, false)
311		if err != nil {
312			return err
313		}
314	}
315
316	return nil
317}
318
319func (c *Conn) dtlsFlushHandshake() error {
320	if err := c.dtlsPackHandshake(); err != nil {
321		return err
322	}
323	if err := c.dtlsFlushPacket(); err != nil {
324		return err
325	}
326
327	return nil
328}
329
330// dtlsPackRecord packs a single record to the pending packet, flushing it
331// if necessary. The caller should call dtlsFlushPacket to flush the current
332// pending packet afterwards.
333func (c *Conn) dtlsPackRecord(typ recordType, data []byte, mustPack bool) (n int, err error) {
334	recordHeaderLen := dtlsRecordHeaderLen
335	maxLen := c.config.Bugs.MaxHandshakeRecordLength
336	if maxLen <= 0 {
337		maxLen = 1024
338	}
339
340	b := c.out.newBlock()
341
342	explicitIVLen := 0
343	explicitIVIsSeq := false
344
345	if cbc, ok := c.out.cipher.(cbcMode); ok {
346		// Block cipher modes have an explicit IV.
347		explicitIVLen = cbc.BlockSize()
348	} else if aead, ok := c.out.cipher.(*tlsAead); ok {
349		if aead.explicitNonce {
350			explicitIVLen = 8
351			// The AES-GCM construction in TLS has an explicit nonce so that
352			// the nonce can be random. However, the nonce is only 8 bytes
353			// which is too small for a secure, random nonce. Therefore we
354			// use the sequence number as the nonce.
355			explicitIVIsSeq = true
356		}
357	} else if _, ok := c.out.cipher.(nullCipher); !ok && c.out.cipher != nil {
358		panic("Unknown cipher")
359	}
360	b.resize(recordHeaderLen + explicitIVLen + len(data))
361	// TODO(nharper): DTLS 1.3 will likely need to set this to
362	// recordTypeApplicationData if c.out.cipher != nil.
363	b.data[0] = byte(typ)
364	vers := c.wireVersion
365	if vers == 0 {
366		// Some TLS servers fail if the record version is greater than
367		// TLS 1.0 for the initial ClientHello.
368		if c.isDTLS {
369			vers = VersionDTLS10
370		} else {
371			vers = VersionTLS10
372		}
373	}
374	b.data[1] = byte(vers >> 8)
375	b.data[2] = byte(vers)
376	// DTLS records include an explicit sequence number.
377	copy(b.data[3:11], c.out.outSeq[0:])
378	b.data[11] = byte(len(data) >> 8)
379	b.data[12] = byte(len(data))
380	if explicitIVLen > 0 {
381		explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
382		if explicitIVIsSeq {
383			copy(explicitIV, c.out.outSeq[:])
384		} else {
385			if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
386				return
387			}
388		}
389	}
390	copy(b.data[recordHeaderLen+explicitIVLen:], data)
391	c.out.encrypt(b, explicitIVLen, typ)
392
393	// Flush the current pending packet if necessary.
394	if !mustPack && len(b.data)+len(c.pendingPacket) > c.config.Bugs.PackHandshakeRecords {
395		err = c.dtlsFlushPacket()
396		if err != nil {
397			c.out.freeBlock(b)
398			return
399		}
400	}
401
402	// Add the record to the pending packet.
403	c.pendingPacket = append(c.pendingPacket, b.data...)
404	c.out.freeBlock(b)
405	n = len(data)
406	return
407}
408
409func (c *Conn) dtlsFlushPacket() error {
410	if len(c.pendingPacket) == 0 {
411		return nil
412	}
413	_, err := c.conn.Write(c.pendingPacket)
414	c.pendingPacket = nil
415	return err
416}
417
418func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
419	// Assemble a full handshake message.  For test purposes, this
420	// implementation assumes fragments arrive in order. It may
421	// need to be cleverer if we ever test BoringSSL's retransmit
422	// behavior.
423	for len(c.handMsg) < 4+c.handMsgLen {
424		// Get a new handshake record if the previous has been
425		// exhausted.
426		if c.hand.Len() == 0 {
427			if err := c.in.err; err != nil {
428				return nil, err
429			}
430			if err := c.readRecord(recordTypeHandshake); err != nil {
431				return nil, err
432			}
433		}
434
435		// Read the next fragment. It must fit entirely within
436		// the record.
437		if c.hand.Len() < 12 {
438			return nil, errors.New("dtls: bad handshake record")
439		}
440		header := c.hand.Next(12)
441		fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
442		fragSeq := uint16(header[4])<<8 | uint16(header[5])
443		fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
444		fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
445
446		if c.hand.Len() < fragLen {
447			return nil, errors.New("dtls: fragment length too long")
448		}
449		fragment := c.hand.Next(fragLen)
450
451		// Check it's a fragment for the right message.
452		if fragSeq != c.recvHandshakeSeq {
453			return nil, errors.New("dtls: bad handshake sequence number")
454		}
455
456		// Check that the length is consistent.
457		if c.handMsg == nil {
458			c.handMsgLen = fragN
459			if c.handMsgLen > maxHandshake {
460				return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
461			}
462			// Start with the TLS handshake header,
463			// without the DTLS bits.
464			c.handMsg = append([]byte{}, header[:4]...)
465		} else if fragN != c.handMsgLen {
466			return nil, errors.New("dtls: bad handshake length")
467		}
468
469		// Add the fragment to the pending message.
470		if 4+fragOff != len(c.handMsg) {
471			return nil, errors.New("dtls: bad fragment offset")
472		}
473		if fragOff+fragLen > c.handMsgLen {
474			return nil, errors.New("dtls: bad fragment length")
475		}
476		c.handMsg = append(c.handMsg, fragment...)
477	}
478	c.recvHandshakeSeq++
479	ret := c.handMsg
480	c.handMsg, c.handMsgLen = nil, 0
481	return ret, nil
482}
483
484// DTLSServer returns a new DTLS server side connection
485// using conn as the underlying transport.
486// The configuration config must be non-nil and must have
487// at least one certificate.
488func DTLSServer(conn net.Conn, config *Config) *Conn {
489	c := &Conn{config: config, isDTLS: true, conn: conn}
490	c.init()
491	return c
492}
493
494// DTLSClient returns a new DTLS client side connection
495// using conn as the underlying transport.
496// The config cannot be nil: users must set either ServerHostname or
497// InsecureSkipVerify in the config.
498func DTLSClient(conn net.Conn, config *Config) *Conn {
499	c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn}
500	c.init()
501	return c
502}
503