1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"strings"
10)
11
// clientHelloMsg holds the contents of a ClientHello handshake message in the
// form read and written by its marshal and unmarshal methods.
type clientHelloMsg struct {
	// raw is the complete wire encoding. unmarshal stores its input here
	// and marshal stores (and on later calls returns) its output here.
	raw                          []byte
	// vers is the 16-bit protocol version, sent big-endian.
	vers                         uint16
	// random occupies 32 bytes on the wire.
	random                       []byte
	// sessionId is sent with a one-byte length prefix; unmarshal rejects
	// lengths above 32.
	sessionId                    []byte
	cipherSuites                 []uint16
	compressionMethods           []uint8
	// nextProtoNeg records presence of the (empty) next-protocol
	// negotiation extension.
	nextProtoNeg                 bool
	// serverName is the host name carried in the server name extension;
	// the extension is omitted by marshal when this is empty.
	serverName                   string
	// ocspStapling records a status request extension of OCSP type.
	ocspStapling                 bool
	// scts records presence of the (empty) SCT extension.
	scts                         bool
	supportedCurves              []CurveID
	supportedPoints              []uint8
	// ticketSupported records presence of the session ticket extension and
	// sessionTicket holds that extension's (possibly empty) contents.
	ticketSupported              bool
	sessionTicket                []uint8
	signatureAndHashes           []signatureAndHash
	// secureRenegotiation is the payload of the renegotiation info
	// extension. secureRenegotiationSupported is set by unmarshal when
	// either that extension or the scsvRenegotiation cipher suite is seen.
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	// alpnProtocols lists the protocol names of the ALPN extension.
	alpnProtocols                []string
}
32
33func (m *clientHelloMsg) equal(i interface{}) bool {
34	m1, ok := i.(*clientHelloMsg)
35	if !ok {
36		return false
37	}
38
39	return bytes.Equal(m.raw, m1.raw) &&
40		m.vers == m1.vers &&
41		bytes.Equal(m.random, m1.random) &&
42		bytes.Equal(m.sessionId, m1.sessionId) &&
43		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
44		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
45		m.nextProtoNeg == m1.nextProtoNeg &&
46		m.serverName == m1.serverName &&
47		m.ocspStapling == m1.ocspStapling &&
48		m.scts == m1.scts &&
49		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
50		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
51		m.ticketSupported == m1.ticketSupported &&
52		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
53		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
54		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
55		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
56		eqStrings(m.alpnProtocols, m1.alpnProtocols)
57}
58
// marshal returns the wire encoding of the ClientHello, including the 4-byte
// handshake header. The result is cached in m.raw; if m.raw is already
// non-nil it is returned unchanged without re-encoding. marshal panics if any
// ALPN protocol name is empty or longer than 255 bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// First pass: work out the body length so the buffer is allocated once.
	// Fixed part: version (2), random (32), session id length (1) and data,
	// cipher suites length (2) and data, compression methods length (1) and
	// data.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	numExtensions := 0
	// extensionsLength collects extension payload sizes only; the 4-byte
	// type/length header of every extension is added after all are counted.
	extensionsLength := 0
	if m.nextProtoNeg {
		numExtensions++
	}
	if m.ocspStapling {
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.signatureAndHashes) > 0 {
		extensionsLength += 2 + 2*len(m.signatureAndHashes)
		numExtensions++
	}
	if m.secureRenegotiationSupported {
		extensionsLength += 1 + len(m.secureRenegotiation)
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			// Each name is written with a one-byte length prefix below.
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		numExtensions++
	}
	if numExtensions > 0 {
		// Per-extension headers, then the 2-byte length of the whole
		// extensions block.
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	// Second pass: fill the buffer. Bytes that are never assigned keep the
	// zero value they got from make.
	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// The cipher suite list is prefixed by its size in bytes, which is
	// twice the number of suites; the shifts below encode 2*len as two
	// big-endian bytes.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	// From here on z is a write cursor: it is re-sliced past each field
	// once that field has been written.
	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		// z[2], the name type, is left at zero (host_name).
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = pointFormat
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.signatureAndHashes) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.signatureAndHashes)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAndHash := range m.signatureAndHashes {
			z[0] = sigAndHash.hash
			z[1] = sigAndHash.signature
			z = z[2:]
		}
	}
	if m.secureRenegotiationSupported {
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = byte(len(m.secureRenegotiation) + 1)
		z[4] = byte(len(m.secureRenegotiation))
		z = z[5:]
		copy(z, m.secureRenegotiation)
		z = z[len(m.secureRenegotiation):]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// lengths keeps a view of the two 2-byte length fields so they
		// can be filled in after the names have been written.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// https://tools.ietf.org/html/rfc6962#section-3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}

	m.raw = x

	return x
}
299
// unmarshal parses data, a ClientHello including its 4-byte handshake header,
// into m and reports whether the message was well formed. The header bytes
// themselves are not inspected. data is stored in m.raw, and several slice
// fields (random, sessionId, compressionMethods, sessionTicket,
// secureRenegotiation) alias data rather than copying it. Extensions with an
// unrecognised type are skipped. On a false return m may be partially filled.
func (m *clientHelloMsg) unmarshal(data []byte) bool {
	// 42 = 4 (header) + 2 (version) + 32 (random) + 1 + 2 + 1 (the three
	// length prefixes read below).
	if len(data) < 42 {
		return false
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return false
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	if len(data) < 2 {
		return false
	}
	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
	// they are uint16s, the number must be even.
	cipherSuiteLen := int(data[0])<<8 | int(data[1])
	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
		return false
	}
	numCipherSuites := cipherSuiteLen / 2
	m.cipherSuites = make([]uint16, numCipherSuites)
	for i := 0; i < numCipherSuites; i++ {
		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
		// The presence of this cipher suite value alone marks secure
		// renegotiation as supported.
		if m.cipherSuites[i] == scsvRenegotiation {
			m.secureRenegotiationSupported = true
		}
	}
	data = data[2+cipherSuiteLen:]
	if len(data) < 1 {
		return false
	}
	compressionMethodsLen := int(data[0])
	if len(data) < 1+compressionMethodsLen {
		return false
	}
	m.compressionMethods = data[1 : 1+compressionMethodsLen]

	data = data[1+compressionMethodsLen:]

	// Clear extension-derived fields before parsing the extensions.
	m.nextProtoNeg = false
	m.serverName = ""
	m.ocspStapling = false
	m.ticketSupported = false
	m.sessionTicket = nil
	m.signatureAndHashes = nil
	m.alpnProtocols = nil
	m.scts = false

	if len(data) == 0 {
		// ClientHello is optionally followed by extension data
		return true
	}
	if len(data) < 2 {
		return false
	}

	// The declared extensions length must account for every remaining byte.
	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if extensionsLength != len(data) {
		return false
	}

	for len(data) != 0 {
		// Each extension: 2-byte type, 2-byte length, then the payload.
		if len(data) < 4 {
			return false
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return false
		}

		switch extension {
		case extensionServerName:
			d := data[:length]
			if len(d) < 2 {
				return false
			}
			namesLen := int(d[0])<<8 | int(d[1])
			d = d[2:]
			if len(d) != namesLen {
				return false
			}
			for len(d) > 0 {
				if len(d) < 3 {
					return false
				}
				nameType := d[0]
				nameLen := int(d[1])<<8 | int(d[2])
				d = d[3:]
				if len(d) < nameLen {
					return false
				}
				if nameType == 0 {
					m.serverName = string(d[:nameLen])
					// An SNI value may not include a
					// trailing dot. See
					// https://tools.ietf.org/html/rfc6066#section-3.
					if strings.HasSuffix(m.serverName, ".") {
						return false
					}
					// Leaves the loop over names: only the first
					// entry of type 0 is used.
					break
				}
				d = d[nameLen:]
			}
		case extensionNextProtoNeg:
			if length > 0 {
				return false
			}
			m.nextProtoNeg = true
		case extensionStatusRequest:
			// Only the leading status type byte is examined.
			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
		case extensionSupportedCurves:
			// http://tools.ietf.org/html/rfc4492#section-5.5.1
			if length < 2 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l%2 == 1 || length != l+2 {
				return false
			}
			numCurves := l / 2
			m.supportedCurves = make([]CurveID, numCurves)
			d := data[2:]
			for i := 0; i < numCurves; i++ {
				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
				d = d[2:]
			}
		case extensionSupportedPoints:
			// http://tools.ietf.org/html/rfc4492#section-5.5.2
			if length < 1 {
				return false
			}
			l := int(data[0])
			if length != l+1 {
				return false
			}
			m.supportedPoints = make([]uint8, l)
			copy(m.supportedPoints, data[1:])
		case extensionSessionTicket:
			// http://tools.ietf.org/html/rfc5077#section-3.2
			m.ticketSupported = true
			m.sessionTicket = data[:length]
		case extensionSignatureAlgorithms:
			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
			if length < 2 || length&1 != 0 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return false
			}
			n := l / 2
			d := data[2:]
			m.signatureAndHashes = make([]signatureAndHash, n)
			for i := range m.signatureAndHashes {
				m.signatureAndHashes[i].hash = d[0]
				m.signatureAndHashes[i].signature = d[1]
				d = d[2:]
			}
		case extensionRenegotiationInfo:
			if length == 0 {
				return false
			}
			d := data[:length]
			// One-byte length prefix that must cover the rest exactly.
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return false
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			if length < 2 {
				return false
			}
			l := int(data[0])<<8 | int(data[1])
			if l != length-2 {
				return false
			}
			d := data[2:length]
			for len(d) != 0 {
				stringLen := int(d[0])
				d = d[1:]
				// Empty names and names running past the end of
				// the list are rejected.
				if stringLen == 0 || stringLen > len(d) {
					return false
				}
				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
				d = d[stringLen:]
			}
		case extensionSCT:
			m.scts = true
			if length != 0 {
				return false
			}
		}
		// Advance past this extension's payload, recognised or not.
		data = data[length:]
	}

	return true
}
506
// serverHelloMsg holds the contents of a ServerHello handshake message in the
// form read and written by its marshal and unmarshal methods.
type serverHelloMsg struct {
	// raw is the complete wire encoding. unmarshal stores its input here
	// and marshal stores (and on later calls returns) its output here.
	raw                          []byte
	// vers is the 16-bit protocol version, sent big-endian.
	vers                         uint16
	// random occupies 32 bytes on the wire.
	random                       []byte
	// sessionId is sent with a one-byte length prefix; unmarshal rejects
	// lengths above 32.
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	// nextProtoNeg records presence of the next-protocol negotiation
	// extension and nextProtos holds the protocol names it carries.
	nextProtoNeg                 bool
	nextProtos                   []string
	// ocspStapling records presence of the (empty) status request
	// extension.
	ocspStapling                 bool
	// scts holds the entries of the SCT extension, each of which is sent
	// with a two-byte length prefix.
	scts                         [][]byte
	// ticketSupported records presence of the (empty) session ticket
	// extension.
	ticketSupported              bool
	// secureRenegotiation is the payload of the renegotiation info
	// extension; secureRenegotiationSupported records its presence.
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	// alpnProtocol is the single protocol name of the ALPN extension; the
	// extension is omitted by marshal when this is empty.
	alpnProtocol                 string
}
523
524func (m *serverHelloMsg) equal(i interface{}) bool {
525	m1, ok := i.(*serverHelloMsg)
526	if !ok {
527		return false
528	}
529
530	if len(m.scts) != len(m1.scts) {
531		return false
532	}
533	for i, sct := range m.scts {
534		if !bytes.Equal(sct, m1.scts[i]) {
535			return false
536		}
537	}
538
539	return bytes.Equal(m.raw, m1.raw) &&
540		m.vers == m1.vers &&
541		bytes.Equal(m.random, m1.random) &&
542		bytes.Equal(m.sessionId, m1.sessionId) &&
543		m.cipherSuite == m1.cipherSuite &&
544		m.compressionMethod == m1.compressionMethod &&
545		m.nextProtoNeg == m1.nextProtoNeg &&
546		eqStrings(m.nextProtos, m1.nextProtos) &&
547		m.ocspStapling == m1.ocspStapling &&
548		m.ticketSupported == m1.ticketSupported &&
549		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
550		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
551		m.alpnProtocol == m1.alpnProtocol
552}
553
554func (m *serverHelloMsg) marshal() []byte {
555	if m.raw != nil {
556		return m.raw
557	}
558
559	length := 38 + len(m.sessionId)
560	numExtensions := 0
561	extensionsLength := 0
562
563	nextProtoLen := 0
564	if m.nextProtoNeg {
565		numExtensions++
566		for _, v := range m.nextProtos {
567			nextProtoLen += len(v)
568		}
569		nextProtoLen += len(m.nextProtos)
570		extensionsLength += nextProtoLen
571	}
572	if m.ocspStapling {
573		numExtensions++
574	}
575	if m.ticketSupported {
576		numExtensions++
577	}
578	if m.secureRenegotiationSupported {
579		extensionsLength += 1 + len(m.secureRenegotiation)
580		numExtensions++
581	}
582	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
583		if alpnLen >= 256 {
584			panic("invalid ALPN protocol")
585		}
586		extensionsLength += 2 + 1 + alpnLen
587		numExtensions++
588	}
589	sctLen := 0
590	if len(m.scts) > 0 {
591		for _, sct := range m.scts {
592			sctLen += len(sct) + 2
593		}
594		extensionsLength += 2 + sctLen
595		numExtensions++
596	}
597
598	if numExtensions > 0 {
599		extensionsLength += 4 * numExtensions
600		length += 2 + extensionsLength
601	}
602
603	x := make([]byte, 4+length)
604	x[0] = typeServerHello
605	x[1] = uint8(length >> 16)
606	x[2] = uint8(length >> 8)
607	x[3] = uint8(length)
608	x[4] = uint8(m.vers >> 8)
609	x[5] = uint8(m.vers)
610	copy(x[6:38], m.random)
611	x[38] = uint8(len(m.sessionId))
612	copy(x[39:39+len(m.sessionId)], m.sessionId)
613	z := x[39+len(m.sessionId):]
614	z[0] = uint8(m.cipherSuite >> 8)
615	z[1] = uint8(m.cipherSuite)
616	z[2] = m.compressionMethod
617
618	z = z[3:]
619	if numExtensions > 0 {
620		z[0] = byte(extensionsLength >> 8)
621		z[1] = byte(extensionsLength)
622		z = z[2:]
623	}
624	if m.nextProtoNeg {
625		z[0] = byte(extensionNextProtoNeg >> 8)
626		z[1] = byte(extensionNextProtoNeg & 0xff)
627		z[2] = byte(nextProtoLen >> 8)
628		z[3] = byte(nextProtoLen)
629		z = z[4:]
630
631		for _, v := range m.nextProtos {
632			l := len(v)
633			if l > 255 {
634				l = 255
635			}
636			z[0] = byte(l)
637			copy(z[1:], []byte(v[0:l]))
638			z = z[1+l:]
639		}
640	}
641	if m.ocspStapling {
642		z[0] = byte(extensionStatusRequest >> 8)
643		z[1] = byte(extensionStatusRequest)
644		z = z[4:]
645	}
646	if m.ticketSupported {
647		z[0] = byte(extensionSessionTicket >> 8)
648		z[1] = byte(extensionSessionTicket)
649		z = z[4:]
650	}
651	if m.secureRenegotiationSupported {
652		z[0] = byte(extensionRenegotiationInfo >> 8)
653		z[1] = byte(extensionRenegotiationInfo & 0xff)
654		z[2] = 0
655		z[3] = byte(len(m.secureRenegotiation) + 1)
656		z[4] = byte(len(m.secureRenegotiation))
657		z = z[5:]
658		copy(z, m.secureRenegotiation)
659		z = z[len(m.secureRenegotiation):]
660	}
661	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
662		z[0] = byte(extensionALPN >> 8)
663		z[1] = byte(extensionALPN & 0xff)
664		l := 2 + 1 + alpnLen
665		z[2] = byte(l >> 8)
666		z[3] = byte(l)
667		l -= 2
668		z[4] = byte(l >> 8)
669		z[5] = byte(l)
670		l -= 1
671		z[6] = byte(l)
672		copy(z[7:], []byte(m.alpnProtocol))
673		z = z[7+alpnLen:]
674	}
675	if sctLen > 0 {
676		z[0] = byte(extensionSCT >> 8)
677		z[1] = byte(extensionSCT)
678		l := sctLen + 2
679		z[2] = byte(l >> 8)
680		z[3] = byte(l)
681		z[4] = byte(sctLen >> 8)
682		z[5] = byte(sctLen)
683
684		z = z[6:]
685		for _, sct := range m.scts {
686			z[0] = byte(len(sct) >> 8)
687			z[1] = byte(len(sct))
688			copy(z[2:], sct)
689			z = z[len(sct)+2:]
690		}
691	}
692
693	m.raw = x
694
695	return x
696}
697
// unmarshal parses data, a ServerHello including its 4-byte handshake header,
// into m and reports whether the message was well formed. The header bytes
// themselves are not inspected. data is stored in m.raw, and the random,
// sessionId, secureRenegotiation and scts fields alias data rather than
// copying it. Extensions with an unrecognised type are skipped. On a false
// return m may be partially filled.
func (m *serverHelloMsg) unmarshal(data []byte) bool {
	// 42 = 4 (header) + 2 (version) + 32 (random) + 1 (session id length)
	// + 2 (cipher suite) + 1 (compression method).
	if len(data) < 42 {
		return false
	}
	m.raw = data
	m.vers = uint16(data[4])<<8 | uint16(data[5])
	m.random = data[6:38]
	sessionIdLen := int(data[38])
	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
		return false
	}
	m.sessionId = data[39 : 39+sessionIdLen]
	data = data[39+sessionIdLen:]
	if len(data) < 3 {
		return false
	}
	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
	m.compressionMethod = data[2]
	data = data[3:]

	// Clear extension-derived fields before parsing the extensions.
	m.nextProtoNeg = false
	m.nextProtos = nil
	m.ocspStapling = false
	m.scts = nil
	m.ticketSupported = false
	m.alpnProtocol = ""

	if len(data) == 0 {
		// ServerHello is optionally followed by extension data
		return true
	}
	if len(data) < 2 {
		return false
	}

	// The declared extensions length must account for every remaining byte.
	extensionsLength := int(data[0])<<8 | int(data[1])
	data = data[2:]
	if len(data) != extensionsLength {
		return false
	}

	for len(data) != 0 {
		// Each extension: 2-byte type, 2-byte length, then the payload.
		if len(data) < 4 {
			return false
		}
		extension := uint16(data[0])<<8 | uint16(data[1])
		length := int(data[2])<<8 | int(data[3])
		data = data[4:]
		if len(data) < length {
			return false
		}

		switch extension {
		case extensionNextProtoNeg:
			m.nextProtoNeg = true
			d := data[:length]
			for len(d) > 0 {
				// One-byte length followed by the protocol name;
				// empty names are rejected.
				l := int(d[0])
				d = d[1:]
				if l == 0 || l > len(d) {
					return false
				}
				m.nextProtos = append(m.nextProtos, string(d[:l]))
				d = d[l:]
			}
		case extensionStatusRequest:
			if length > 0 {
				return false
			}
			m.ocspStapling = true
		case extensionSessionTicket:
			if length > 0 {
				return false
			}
			m.ticketSupported = true
		case extensionRenegotiationInfo:
			if length == 0 {
				return false
			}
			d := data[:length]
			// One-byte length prefix that must cover the rest exactly.
			l := int(d[0])
			d = d[1:]
			if l != len(d) {
				return false
			}

			m.secureRenegotiation = d
			m.secureRenegotiationSupported = true
		case extensionALPN:
			d := data[:length]
			if len(d) < 3 {
				return false
			}
			// Outer list length, then exactly one length-prefixed name
			// that must fill the list.
			l := int(d[0])<<8 | int(d[1])
			if l != len(d)-2 {
				return false
			}
			d = d[2:]
			l = int(d[0])
			if l != len(d)-1 {
				return false
			}
			d = d[1:]
			if len(d) == 0 {
				// ALPN protocols must not be empty.
				return false
			}
			m.alpnProtocol = string(d)
		case extensionSCT:
			d := data[:length]

			if len(d) < 2 {
				return false
			}
			l := int(d[0])<<8 | int(d[1])
			d = d[2:]
			// The list length must match exactly and be non-zero.
			if len(d) != l || l == 0 {
				return false
			}

			m.scts = make([][]byte, 0, 3)
			for len(d) != 0 {
				if len(d) < 2 {
					return false
				}
				sctLen := int(d[0])<<8 | int(d[1])
				d = d[2:]
				if sctLen == 0 || len(d) < sctLen {
					return false
				}
				m.scts = append(m.scts, d[:sctLen])
				d = d[sctLen:]
			}
		}
		// Advance past this extension's payload, recognised or not.
		data = data[length:]
	}

	return true
}
837
// certificateMsg holds the contents of a Certificate handshake message.
type certificateMsg struct {
	// raw is the complete wire encoding, set by both marshal and unmarshal.
	raw          []byte
	// certificates holds each certificate's bytes in message order; each
	// is sent with a three-byte length prefix.
	certificates [][]byte
}
842
843func (m *certificateMsg) equal(i interface{}) bool {
844	m1, ok := i.(*certificateMsg)
845	if !ok {
846		return false
847	}
848
849	return bytes.Equal(m.raw, m1.raw) &&
850		eqByteSlices(m.certificates, m1.certificates)
851}
852
853func (m *certificateMsg) marshal() (x []byte) {
854	if m.raw != nil {
855		return m.raw
856	}
857
858	var i int
859	for _, slice := range m.certificates {
860		i += len(slice)
861	}
862
863	length := 3 + 3*len(m.certificates) + i
864	x = make([]byte, 4+length)
865	x[0] = typeCertificate
866	x[1] = uint8(length >> 16)
867	x[2] = uint8(length >> 8)
868	x[3] = uint8(length)
869
870	certificateOctets := length - 3
871	x[4] = uint8(certificateOctets >> 16)
872	x[5] = uint8(certificateOctets >> 8)
873	x[6] = uint8(certificateOctets)
874
875	y := x[7:]
876	for _, slice := range m.certificates {
877		y[0] = uint8(len(slice) >> 16)
878		y[1] = uint8(len(slice) >> 8)
879		y[2] = uint8(len(slice))
880		copy(y[3:], slice)
881		y = y[3+len(slice):]
882	}
883
884	m.raw = x
885	return
886}
887
888func (m *certificateMsg) unmarshal(data []byte) bool {
889	if len(data) < 7 {
890		return false
891	}
892
893	m.raw = data
894	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
895	if uint32(len(data)) != certsLen+7 {
896		return false
897	}
898
899	numCerts := 0
900	d := data[7:]
901	for certsLen > 0 {
902		if len(d) < 4 {
903			return false
904		}
905		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
906		if uint32(len(d)) < 3+certLen {
907			return false
908		}
909		d = d[3+certLen:]
910		certsLen -= 3 + certLen
911		numCerts++
912	}
913
914	m.certificates = make([][]byte, numCerts)
915	d = data[7:]
916	for i := 0; i < numCerts; i++ {
917		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
918		m.certificates[i] = d[3 : 3+certLen]
919		d = d[3+certLen:]
920	}
921
922	return true
923}
924
// serverKeyExchangeMsg holds a ServerKeyExchange handshake message.
type serverKeyExchangeMsg struct {
	// raw is the complete wire encoding, set by both marshal and unmarshal.
	raw []byte
	// key is everything after the 4-byte handshake header.
	key []byte
}
929
930func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
931	m1, ok := i.(*serverKeyExchangeMsg)
932	if !ok {
933		return false
934	}
935
936	return bytes.Equal(m.raw, m1.raw) &&
937		bytes.Equal(m.key, m1.key)
938}
939
940func (m *serverKeyExchangeMsg) marshal() []byte {
941	if m.raw != nil {
942		return m.raw
943	}
944	length := len(m.key)
945	x := make([]byte, length+4)
946	x[0] = typeServerKeyExchange
947	x[1] = uint8(length >> 16)
948	x[2] = uint8(length >> 8)
949	x[3] = uint8(length)
950	copy(x[4:], m.key)
951
952	m.raw = x
953	return x
954}
955
956func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
957	m.raw = data
958	if len(data) < 4 {
959		return false
960	}
961	m.key = data[4:]
962	return true
963}
964
// certificateStatusMsg holds a CertificateStatus handshake message.
type certificateStatusMsg struct {
	// raw is the complete wire encoding, set by both marshal and unmarshal.
	raw        []byte
	// statusType is the single byte following the handshake header.
	statusType uint8
	// response is the length-prefixed payload; it is only read or written
	// when statusType is statusTypeOCSP.
	response   []byte
}
970
971func (m *certificateStatusMsg) equal(i interface{}) bool {
972	m1, ok := i.(*certificateStatusMsg)
973	if !ok {
974		return false
975	}
976
977	return bytes.Equal(m.raw, m1.raw) &&
978		m.statusType == m1.statusType &&
979		bytes.Equal(m.response, m1.response)
980}
981
982func (m *certificateStatusMsg) marshal() []byte {
983	if m.raw != nil {
984		return m.raw
985	}
986
987	var x []byte
988	if m.statusType == statusTypeOCSP {
989		x = make([]byte, 4+4+len(m.response))
990		x[0] = typeCertificateStatus
991		l := len(m.response) + 4
992		x[1] = byte(l >> 16)
993		x[2] = byte(l >> 8)
994		x[3] = byte(l)
995		x[4] = statusTypeOCSP
996
997		l -= 4
998		x[5] = byte(l >> 16)
999		x[6] = byte(l >> 8)
1000		x[7] = byte(l)
1001		copy(x[8:], m.response)
1002	} else {
1003		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
1004	}
1005
1006	m.raw = x
1007	return x
1008}
1009
1010func (m *certificateStatusMsg) unmarshal(data []byte) bool {
1011	m.raw = data
1012	if len(data) < 5 {
1013		return false
1014	}
1015	m.statusType = data[4]
1016
1017	m.response = nil
1018	if m.statusType == statusTypeOCSP {
1019		if len(data) < 8 {
1020			return false
1021		}
1022		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
1023		if uint32(len(data)) != 4+4+respLen {
1024			return false
1025		}
1026		m.response = data[8:]
1027	}
1028	return true
1029}
1030
// serverHelloDoneMsg is a ServerHelloDone handshake message. It has no fields:
// its encoding is the 4-byte handshake header alone.
type serverHelloDoneMsg struct{}
1032
1033func (m *serverHelloDoneMsg) equal(i interface{}) bool {
1034	_, ok := i.(*serverHelloDoneMsg)
1035	return ok
1036}
1037
1038func (m *serverHelloDoneMsg) marshal() []byte {
1039	x := make([]byte, 4)
1040	x[0] = typeServerHelloDone
1041	return x
1042}
1043
1044func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
1045	return len(data) == 4
1046}
1047
// clientKeyExchangeMsg holds a ClientKeyExchange handshake message.
type clientKeyExchangeMsg struct {
	// raw is the complete wire encoding, set by both marshal and unmarshal.
	raw        []byte
	// ciphertext is everything after the 4-byte handshake header.
	ciphertext []byte
}
1052
1053func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
1054	m1, ok := i.(*clientKeyExchangeMsg)
1055	if !ok {
1056		return false
1057	}
1058
1059	return bytes.Equal(m.raw, m1.raw) &&
1060		bytes.Equal(m.ciphertext, m1.ciphertext)
1061}
1062
1063func (m *clientKeyExchangeMsg) marshal() []byte {
1064	if m.raw != nil {
1065		return m.raw
1066	}
1067	length := len(m.ciphertext)
1068	x := make([]byte, length+4)
1069	x[0] = typeClientKeyExchange
1070	x[1] = uint8(length >> 16)
1071	x[2] = uint8(length >> 8)
1072	x[3] = uint8(length)
1073	copy(x[4:], m.ciphertext)
1074
1075	m.raw = x
1076	return x
1077}
1078
1079func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
1080	m.raw = data
1081	if len(data) < 4 {
1082		return false
1083	}
1084	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1085	if l != len(data)-4 {
1086		return false
1087	}
1088	m.ciphertext = data[4:]
1089	return true
1090}
1091
// finishedMsg holds a Finished handshake message.
type finishedMsg struct {
	// raw is the complete wire encoding, set by both marshal and unmarshal.
	raw        []byte
	// verifyData is everything after the 4-byte handshake header.
	verifyData []byte
}
1096
1097func (m *finishedMsg) equal(i interface{}) bool {
1098	m1, ok := i.(*finishedMsg)
1099	if !ok {
1100		return false
1101	}
1102
1103	return bytes.Equal(m.raw, m1.raw) &&
1104		bytes.Equal(m.verifyData, m1.verifyData)
1105}
1106
1107func (m *finishedMsg) marshal() (x []byte) {
1108	if m.raw != nil {
1109		return m.raw
1110	}
1111
1112	x = make([]byte, 4+len(m.verifyData))
1113	x[0] = typeFinished
1114	x[3] = byte(len(m.verifyData))
1115	copy(x[4:], m.verifyData)
1116	m.raw = x
1117	return
1118}
1119
1120func (m *finishedMsg) unmarshal(data []byte) bool {
1121	m.raw = data
1122	if len(data) < 4 {
1123		return false
1124	}
1125	m.verifyData = data[4:]
1126	return true
1127}
1128
// nextProtoMsg holds a NextProtocol handshake message.
type nextProtoMsg struct {
	// raw is the complete wire encoding, set by both marshal and unmarshal.
	raw   []byte
	// proto is the protocol name, sent with a one-byte length prefix and
	// followed by padding.
	proto string
}
1133
1134func (m *nextProtoMsg) equal(i interface{}) bool {
1135	m1, ok := i.(*nextProtoMsg)
1136	if !ok {
1137		return false
1138	}
1139
1140	return bytes.Equal(m.raw, m1.raw) &&
1141		m.proto == m1.proto
1142}
1143
1144func (m *nextProtoMsg) marshal() []byte {
1145	if m.raw != nil {
1146		return m.raw
1147	}
1148	l := len(m.proto)
1149	if l > 255 {
1150		l = 255
1151	}
1152
1153	padding := 32 - (l+2)%32
1154	length := l + padding + 2
1155	x := make([]byte, length+4)
1156	x[0] = typeNextProtocol
1157	x[1] = uint8(length >> 16)
1158	x[2] = uint8(length >> 8)
1159	x[3] = uint8(length)
1160
1161	y := x[4:]
1162	y[0] = byte(l)
1163	copy(y[1:], []byte(m.proto[0:l]))
1164	y = y[1+l:]
1165	y[0] = byte(padding)
1166
1167	m.raw = x
1168
1169	return x
1170}
1171
1172func (m *nextProtoMsg) unmarshal(data []byte) bool {
1173	m.raw = data
1174
1175	if len(data) < 5 {
1176		return false
1177	}
1178	data = data[4:]
1179	protoLen := int(data[0])
1180	data = data[1:]
1181	if len(data) < protoLen {
1182		return false
1183	}
1184	m.proto = string(data[0:protoLen])
1185	data = data[protoLen:]
1186
1187	if len(data) < 1 {
1188		return false
1189	}
1190	paddingLen := int(data[0])
1191	data = data[1:]
1192	if len(data) != paddingLen {
1193		return false
1194	}
1195
1196	return true
1197}
1198
// certificateRequestMsg holds a CertificateRequest handshake message.
type certificateRequestMsg struct {
	// raw is the complete wire encoding, set by both marshal and unmarshal.
	raw []byte
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2. unmarshal reads this flag but never sets it, so it must be set
	// before unmarshal is called.
	hasSignatureAndHash bool

	// certificateTypes is sent with a one-byte count prefix.
	certificateTypes       []byte
	// signatureAndHashes is only encoded or decoded when
	// hasSignatureAndHash is true.
	signatureAndHashes     []signatureAndHash
	// certificateAuthorities holds the entries of the authorities list,
	// each sent with a two-byte length prefix.
	certificateAuthorities [][]byte
}
1210
1211func (m *certificateRequestMsg) equal(i interface{}) bool {
1212	m1, ok := i.(*certificateRequestMsg)
1213	if !ok {
1214		return false
1215	}
1216
1217	return bytes.Equal(m.raw, m1.raw) &&
1218		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
1219		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
1220		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
1221}
1222
1223func (m *certificateRequestMsg) marshal() (x []byte) {
1224	if m.raw != nil {
1225		return m.raw
1226	}
1227
1228	// See http://tools.ietf.org/html/rfc4346#section-7.4.4
1229	length := 1 + len(m.certificateTypes) + 2
1230	casLength := 0
1231	for _, ca := range m.certificateAuthorities {
1232		casLength += 2 + len(ca)
1233	}
1234	length += casLength
1235
1236	if m.hasSignatureAndHash {
1237		length += 2 + 2*len(m.signatureAndHashes)
1238	}
1239
1240	x = make([]byte, 4+length)
1241	x[0] = typeCertificateRequest
1242	x[1] = uint8(length >> 16)
1243	x[2] = uint8(length >> 8)
1244	x[3] = uint8(length)
1245
1246	x[4] = uint8(len(m.certificateTypes))
1247
1248	copy(x[5:], m.certificateTypes)
1249	y := x[5+len(m.certificateTypes):]
1250
1251	if m.hasSignatureAndHash {
1252		n := len(m.signatureAndHashes) * 2
1253		y[0] = uint8(n >> 8)
1254		y[1] = uint8(n)
1255		y = y[2:]
1256		for _, sigAndHash := range m.signatureAndHashes {
1257			y[0] = sigAndHash.hash
1258			y[1] = sigAndHash.signature
1259			y = y[2:]
1260		}
1261	}
1262
1263	y[0] = uint8(casLength >> 8)
1264	y[1] = uint8(casLength)
1265	y = y[2:]
1266	for _, ca := range m.certificateAuthorities {
1267		y[0] = uint8(len(ca) >> 8)
1268		y[1] = uint8(len(ca))
1269		y = y[2:]
1270		copy(y, ca)
1271		y = y[len(ca):]
1272	}
1273
1274	m.raw = x
1275	return
1276}
1277
1278func (m *certificateRequestMsg) unmarshal(data []byte) bool {
1279	m.raw = data
1280
1281	if len(data) < 5 {
1282		return false
1283	}
1284
1285	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1286	if uint32(len(data))-4 != length {
1287		return false
1288	}
1289
1290	numCertTypes := int(data[4])
1291	data = data[5:]
1292	if numCertTypes == 0 || len(data) <= numCertTypes {
1293		return false
1294	}
1295
1296	m.certificateTypes = make([]byte, numCertTypes)
1297	if copy(m.certificateTypes, data) != numCertTypes {
1298		return false
1299	}
1300
1301	data = data[numCertTypes:]
1302
1303	if m.hasSignatureAndHash {
1304		if len(data) < 2 {
1305			return false
1306		}
1307		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
1308		data = data[2:]
1309		if sigAndHashLen&1 != 0 {
1310			return false
1311		}
1312		if len(data) < int(sigAndHashLen) {
1313			return false
1314		}
1315		numSigAndHash := sigAndHashLen / 2
1316		m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
1317		for i := range m.signatureAndHashes {
1318			m.signatureAndHashes[i].hash = data[0]
1319			m.signatureAndHashes[i].signature = data[1]
1320			data = data[2:]
1321		}
1322	}
1323
1324	if len(data) < 2 {
1325		return false
1326	}
1327	casLength := uint16(data[0])<<8 | uint16(data[1])
1328	data = data[2:]
1329	if len(data) < int(casLength) {
1330		return false
1331	}
1332	cas := make([]byte, casLength)
1333	copy(cas, data)
1334	data = data[casLength:]
1335
1336	m.certificateAuthorities = nil
1337	for len(cas) > 0 {
1338		if len(cas) < 2 {
1339			return false
1340		}
1341		caLen := uint16(cas[0])<<8 | uint16(cas[1])
1342		cas = cas[2:]
1343
1344		if len(cas) < int(caLen) {
1345			return false
1346		}
1347
1348		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
1349		cas = cas[caLen:]
1350	}
1351
1352	return len(data) == 0
1353}
1354
// certificateVerifyMsg holds the fields of a handshake message of type
// typeCertificateVerify.
type certificateVerifyMsg struct {
	// raw is the complete wire encoding. marshal and unmarshal both set it,
	// and marshal returns it unchanged when it is non-nil.
	raw                 []byte
	// hasSignatureAndHash selects the encoding: when true, the two bytes of
	// signatureAndHash sit in front of the signature. marshal and unmarshal
	// only read this flag, so it must be set before either is called.
	hasSignatureAndHash bool
	// signatureAndHash is the hash/signature pair; it is written and parsed
	// only when hasSignatureAndHash is true.
	signatureAndHash    signatureAndHash
	// signature is the signature, encoded with a 16-bit length prefix. After
	// unmarshal it aliases the slice that was passed in.
	signature           []byte
}
1361
1362func (m *certificateVerifyMsg) equal(i interface{}) bool {
1363	m1, ok := i.(*certificateVerifyMsg)
1364	if !ok {
1365		return false
1366	}
1367
1368	return bytes.Equal(m.raw, m1.raw) &&
1369		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
1370		m.signatureAndHash.hash == m1.signatureAndHash.hash &&
1371		m.signatureAndHash.signature == m1.signatureAndHash.signature &&
1372		bytes.Equal(m.signature, m1.signature)
1373}
1374
1375func (m *certificateVerifyMsg) marshal() (x []byte) {
1376	if m.raw != nil {
1377		return m.raw
1378	}
1379
1380	// See http://tools.ietf.org/html/rfc4346#section-7.4.8
1381	siglength := len(m.signature)
1382	length := 2 + siglength
1383	if m.hasSignatureAndHash {
1384		length += 2
1385	}
1386	x = make([]byte, 4+length)
1387	x[0] = typeCertificateVerify
1388	x[1] = uint8(length >> 16)
1389	x[2] = uint8(length >> 8)
1390	x[3] = uint8(length)
1391	y := x[4:]
1392	if m.hasSignatureAndHash {
1393		y[0] = m.signatureAndHash.hash
1394		y[1] = m.signatureAndHash.signature
1395		y = y[2:]
1396	}
1397	y[0] = uint8(siglength >> 8)
1398	y[1] = uint8(siglength)
1399	copy(y[2:], m.signature)
1400
1401	m.raw = x
1402
1403	return
1404}
1405
1406func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
1407	m.raw = data
1408
1409	if len(data) < 6 {
1410		return false
1411	}
1412
1413	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1414	if uint32(len(data))-4 != length {
1415		return false
1416	}
1417
1418	data = data[4:]
1419	if m.hasSignatureAndHash {
1420		m.signatureAndHash.hash = data[0]
1421		m.signatureAndHash.signature = data[1]
1422		data = data[2:]
1423	}
1424
1425	if len(data) < 2 {
1426		return false
1427	}
1428	siglength := int(data[0])<<8 + int(data[1])
1429	data = data[2:]
1430	if len(data) != siglength {
1431		return false
1432	}
1433
1434	m.signature = data
1435
1436	return true
1437}
1438
// newSessionTicketMsg holds the fields of a handshake message of type
// typeNewSessionTicket.
type newSessionTicketMsg struct {
	// raw is the complete wire encoding. marshal and unmarshal both set it,
	// and marshal returns it unchanged when it is non-nil.
	raw    []byte
	// ticket is the ticket payload, encoded with a 16-bit length prefix.
	// After unmarshal it aliases the slice that was passed in.
	ticket []byte
}
1443
1444func (m *newSessionTicketMsg) equal(i interface{}) bool {
1445	m1, ok := i.(*newSessionTicketMsg)
1446	if !ok {
1447		return false
1448	}
1449
1450	return bytes.Equal(m.raw, m1.raw) &&
1451		bytes.Equal(m.ticket, m1.ticket)
1452}
1453
1454func (m *newSessionTicketMsg) marshal() (x []byte) {
1455	if m.raw != nil {
1456		return m.raw
1457	}
1458
1459	// See http://tools.ietf.org/html/rfc5077#section-3.3
1460	ticketLen := len(m.ticket)
1461	length := 2 + 4 + ticketLen
1462	x = make([]byte, 4+length)
1463	x[0] = typeNewSessionTicket
1464	x[1] = uint8(length >> 16)
1465	x[2] = uint8(length >> 8)
1466	x[3] = uint8(length)
1467	x[8] = uint8(ticketLen >> 8)
1468	x[9] = uint8(ticketLen)
1469	copy(x[10:], m.ticket)
1470
1471	m.raw = x
1472
1473	return
1474}
1475
1476func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
1477	m.raw = data
1478
1479	if len(data) < 10 {
1480		return false
1481	}
1482
1483	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1484	if uint32(len(data))-4 != length {
1485		return false
1486	}
1487
1488	ticketLen := int(data[8])<<8 + int(data[9])
1489	if len(data)-10 != ticketLen {
1490		return false
1491	}
1492
1493	m.ticket = data[10:]
1494
1495	return true
1496}
1497
// helloRequestMsg is a handshake message with no fields: its marshal method
// emits only a four-byte header of type typeHelloRequest with zero length.
type helloRequestMsg struct {
}
1500
1501func (*helloRequestMsg) marshal() []byte {
1502	return []byte{typeHelloRequest, 0, 0, 0}
1503}
1504
1505func (*helloRequestMsg) unmarshal(data []byte) bool {
1506	return len(data) == 4
1507}
1508
// eqUint16s reports whether x and y have the same length and hold equal
// values at every index. A nil slice and an empty slice are equal.
func eqUint16s(x, y []uint16) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	for i := 0; i < n; i++ {
		if x[i] != y[i] {
			return false
		}
	}
	return true
}
1520
1521func eqCurveIDs(x, y []CurveID) bool {
1522	if len(x) != len(y) {
1523		return false
1524	}
1525	for i, v := range x {
1526		if y[i] != v {
1527			return false
1528		}
1529	}
1530	return true
1531}
1532
// eqStrings reports whether x and y have the same length and hold equal
// strings at every index. A nil slice and an empty slice are equal.
func eqStrings(x, y []string) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	for i := 0; i < n; i++ {
		if x[i] != y[i] {
			return false
		}
	}
	return true
}
1544
// eqByteSlices reports whether x and y have the same length and hold byte
// slices that are equal under bytes.Equal at every index.
func eqByteSlices(x, y [][]byte) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	for i := 0; i < n; i++ {
		if !bytes.Equal(x[i], y[i]) {
			return false
		}
	}
	return true
}
1556
1557func eqSignatureAndHashes(x, y []signatureAndHash) bool {
1558	if len(x) != len(y) {
1559		return false
1560	}
1561	for i, v := range x {
1562		v2 := y[i]
1563		if v.hash != v2.hash || v.signature != v2.signature {
1564			return false
1565		}
1566	}
1567	return true
1568}
1569