1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"strings"
10)
11
// clientHelloMsg is the decoded form of a ClientHello handshake message.
//
// raw caches the wire encoding. marshal returns it unchanged when it is
// non-nil and otherwise stores the encoding it builds there. unmarshal sets
// it to its input.
//
// The fixed-position fields are:
//   - vers: the protocol version.
//   - random: the 32-byte client random.
//   - sessionId: at most 32 bytes when produced by unmarshal.
//   - cipherSuites and compressionMethods.
//
// The remaining fields reflect the extensions that marshal and unmarshal
// understand:
//   - nextProtoNeg and scts: presence of the empty NPN and SCT extensions.
//   - serverName: the SNI host_name. unmarshal rejects a trailing dot.
//   - ocspStapling: a status_request extension of OCSP type.
//   - supportedCurves and supportedPoints: the elliptic-curve extensions.
//   - ticketSupported and sessionTicket: presence and body of the session
//     ticket extension.
//   - supportedSignatureAlgorithms: the signature_algorithms list.
//   - secureRenegotiationSupported: set by unmarshal when the renegotiation
//     SCSV appears in cipherSuites or a renegotiation_info extension is
//     present. secureRenegotiation holds that extension's payload.
//   - alpnProtocols: ALPN protocol names. Each must be 1..255 bytes or
//     marshal panics.
type clientHelloMsg struct {
	raw                          []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuites                 []uint16
	compressionMethods           []uint8
	nextProtoNeg                 bool
	serverName                   string
	ocspStapling                 bool
	scts                         bool
	supportedCurves              []CurveID
	supportedPoints              []uint8
	ticketSupported              bool
	sessionTicket                []uint8
	supportedSignatureAlgorithms []SignatureScheme
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocols                []string
}
32
33func (m *clientHelloMsg) equal(i interface{}) bool {
34	m1, ok := i.(*clientHelloMsg)
35	if !ok {
36		return false
37	}
38
39	return bytes.Equal(m.raw, m1.raw) &&
40		m.vers == m1.vers &&
41		bytes.Equal(m.random, m1.random) &&
42		bytes.Equal(m.sessionId, m1.sessionId) &&
43		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
44		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
45		m.nextProtoNeg == m1.nextProtoNeg &&
46		m.serverName == m1.serverName &&
47		m.ocspStapling == m1.ocspStapling &&
48		m.scts == m1.scts &&
49		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
50		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
51		m.ticketSupported == m1.ticketSupported &&
52		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
53		eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
54		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
55		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
56		eqStrings(m.alpnProtocols, m1.alpnProtocols)
57}
58
// marshal returns the wire encoding of m as a ClientHello handshake message.
// The layout is:
//   - a 4-byte header: message type plus 24-bit body length;
//   - the version, random, session ID, cipher suites and compression methods;
//   - if at least one extension is enabled, a 2-byte total extensions length
//     followed by the extensions.
//
// If m.raw is non-nil it is returned as-is without re-encoding. Otherwise
// the freshly built encoding is stored in m.raw before being returned.
// It panics if any ALPN protocol name is empty or longer than 255 bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// First pass: size the message. The fixed part is:
	//   - version (2) and random (32);
	//   - session ID length (1) plus the ID;
	//   - cipher suites length (2) plus 2 bytes per suite;
	//   - compression methods length (1) plus the methods.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	// extensionsLength accumulates extension bodies only. The 4-byte
	// type+length header of every extension is added once numExtensions is
	// known. The order of the checks here does not have to match the write
	// order below because only the totals matter.
	numExtensions := 0
	extensionsLength := 0
	if m.nextProtoNeg {
		// NPN is sent with an empty body.
		numExtensions++
	}
	if m.ocspStapling {
		// status_type (1) + two empty uint16-length lists.
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// list length (2) + name_type (1) + name length (2) + name.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		// The ticket is sent as the raw extension body.
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
		numExtensions++
	}
	if m.secureRenegotiationSupported {
		// One-byte length prefix + payload.
		extensionsLength += 1 + len(m.secureRenegotiation)
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		// 2-byte list length, then each name with a 1-byte length prefix.
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		// SCT request is sent with an empty body.
		numExtensions++
	}
	if numExtensions > 0 {
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	// Second pass: write the fields. make zero-fills x, so any byte not
	// explicitly assigned below is 0.
	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// Cipher suites length is 2*len(m.cipherSuites) as a big-endian uint16.
	// Its high byte is len>>7 and its low byte is len<<1.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		// A single-entry ServerNameList: list length, then z[2] (name_type)
		// is left as 0 (host_name), then the 2-byte name length and name.
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// Inner list length excludes its own 2-byte prefix.
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		// One-byte inner list length.
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = pointFormat
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.supportedSignatureAlgorithms) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.supportedSignatureAlgorithms)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAlgo := range m.supportedSignatureAlgorithms {
			z[0] = byte(sigAlgo >> 8)
			z[1] = byte(sigAlgo)
			z = z[2:]
		}
	}
	if m.secureRenegotiationSupported {
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = byte(len(m.secureRenegotiation) + 1)
		z[4] = byte(len(m.secureRenegotiation))
		z = z[5:]
		copy(z, m.secureRenegotiation)
		z = z[len(m.secureRenegotiation):]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// Remember where the extension length (lengths[0:2]) and list
		// length (lengths[2:4]) go; they are filled in after the protocol
		// names have been written and summed.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// https://tools.ietf.org/html/rfc6962#section-3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}

	m.raw = x

	return x
}
299
300func (m *clientHelloMsg) unmarshal(data []byte) bool {
301	if len(data) < 42 {
302		return false
303	}
304	m.raw = data
305	m.vers = uint16(data[4])<<8 | uint16(data[5])
306	m.random = data[6:38]
307	sessionIdLen := int(data[38])
308	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
309		return false
310	}
311	m.sessionId = data[39 : 39+sessionIdLen]
312	data = data[39+sessionIdLen:]
313	if len(data) < 2 {
314		return false
315	}
316	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
317	// they are uint16s, the number must be even.
318	cipherSuiteLen := int(data[0])<<8 | int(data[1])
319	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
320		return false
321	}
322	numCipherSuites := cipherSuiteLen / 2
323	m.cipherSuites = make([]uint16, numCipherSuites)
324	for i := 0; i < numCipherSuites; i++ {
325		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
326		if m.cipherSuites[i] == scsvRenegotiation {
327			m.secureRenegotiationSupported = true
328		}
329	}
330	data = data[2+cipherSuiteLen:]
331	if len(data) < 1 {
332		return false
333	}
334	compressionMethodsLen := int(data[0])
335	if len(data) < 1+compressionMethodsLen {
336		return false
337	}
338	m.compressionMethods = data[1 : 1+compressionMethodsLen]
339
340	data = data[1+compressionMethodsLen:]
341
342	m.nextProtoNeg = false
343	m.serverName = ""
344	m.ocspStapling = false
345	m.ticketSupported = false
346	m.sessionTicket = nil
347	m.supportedSignatureAlgorithms = nil
348	m.alpnProtocols = nil
349	m.scts = false
350
351	if len(data) == 0 {
352		// ClientHello is optionally followed by extension data
353		return true
354	}
355	if len(data) < 2 {
356		return false
357	}
358
359	extensionsLength := int(data[0])<<8 | int(data[1])
360	data = data[2:]
361	if extensionsLength != len(data) {
362		return false
363	}
364
365	for len(data) != 0 {
366		if len(data) < 4 {
367			return false
368		}
369		extension := uint16(data[0])<<8 | uint16(data[1])
370		length := int(data[2])<<8 | int(data[3])
371		data = data[4:]
372		if len(data) < length {
373			return false
374		}
375
376		switch extension {
377		case extensionServerName:
378			d := data[:length]
379			if len(d) < 2 {
380				return false
381			}
382			namesLen := int(d[0])<<8 | int(d[1])
383			d = d[2:]
384			if len(d) != namesLen {
385				return false
386			}
387			for len(d) > 0 {
388				if len(d) < 3 {
389					return false
390				}
391				nameType := d[0]
392				nameLen := int(d[1])<<8 | int(d[2])
393				d = d[3:]
394				if len(d) < nameLen {
395					return false
396				}
397				if nameType == 0 {
398					m.serverName = string(d[:nameLen])
399					// An SNI value may not include a
400					// trailing dot. See
401					// https://tools.ietf.org/html/rfc6066#section-3.
402					if strings.HasSuffix(m.serverName, ".") {
403						return false
404					}
405					break
406				}
407				d = d[nameLen:]
408			}
409		case extensionNextProtoNeg:
410			if length > 0 {
411				return false
412			}
413			m.nextProtoNeg = true
414		case extensionStatusRequest:
415			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
416		case extensionSupportedCurves:
417			// http://tools.ietf.org/html/rfc4492#section-5.5.1
418			if length < 2 {
419				return false
420			}
421			l := int(data[0])<<8 | int(data[1])
422			if l%2 == 1 || length != l+2 {
423				return false
424			}
425			numCurves := l / 2
426			m.supportedCurves = make([]CurveID, numCurves)
427			d := data[2:]
428			for i := 0; i < numCurves; i++ {
429				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
430				d = d[2:]
431			}
432		case extensionSupportedPoints:
433			// http://tools.ietf.org/html/rfc4492#section-5.5.2
434			if length < 1 {
435				return false
436			}
437			l := int(data[0])
438			if length != l+1 {
439				return false
440			}
441			m.supportedPoints = make([]uint8, l)
442			copy(m.supportedPoints, data[1:])
443		case extensionSessionTicket:
444			// http://tools.ietf.org/html/rfc5077#section-3.2
445			m.ticketSupported = true
446			m.sessionTicket = data[:length]
447		case extensionSignatureAlgorithms:
448			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
449			if length < 2 || length&1 != 0 {
450				return false
451			}
452			l := int(data[0])<<8 | int(data[1])
453			if l != length-2 {
454				return false
455			}
456			n := l / 2
457			d := data[2:]
458			m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
459			for i := range m.supportedSignatureAlgorithms {
460				m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
461				d = d[2:]
462			}
463		case extensionRenegotiationInfo:
464			if length == 0 {
465				return false
466			}
467			d := data[:length]
468			l := int(d[0])
469			d = d[1:]
470			if l != len(d) {
471				return false
472			}
473
474			m.secureRenegotiation = d
475			m.secureRenegotiationSupported = true
476		case extensionALPN:
477			if length < 2 {
478				return false
479			}
480			l := int(data[0])<<8 | int(data[1])
481			if l != length-2 {
482				return false
483			}
484			d := data[2:length]
485			for len(d) != 0 {
486				stringLen := int(d[0])
487				d = d[1:]
488				if stringLen == 0 || stringLen > len(d) {
489					return false
490				}
491				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
492				d = d[stringLen:]
493			}
494		case extensionSCT:
495			m.scts = true
496			if length != 0 {
497				return false
498			}
499		}
500		data = data[length:]
501	}
502
503	return true
504}
505
// serverHelloMsg is the decoded form of a ServerHello handshake message.
//
// raw caches the wire encoding. marshal returns it unchanged when it is
// non-nil and otherwise stores the encoding it builds there. unmarshal sets
// it to its input.
//
// The fixed-position fields are:
//   - vers: the protocol version.
//   - random: the 32-byte server random.
//   - sessionId: at most 32 bytes when produced by unmarshal.
//   - cipherSuite and compressionMethod: the selected values.
//
// The extension-derived fields are:
//   - nextProtoNeg and nextProtos: the NPN extension and the protocols it
//     lists.
//   - ocspStapling and ticketSupported: presence of the empty status_request
//     and session ticket extensions.
//   - scts: the list of signed certificate timestamps. unmarshal rejects
//     empty entries.
//   - secureRenegotiationSupported and secureRenegotiation: presence and
//     payload of renegotiation_info.
//   - alpnProtocol: the single selected ALPN protocol. marshal panics if it
//     is 256 bytes or longer.
type serverHelloMsg struct {
	raw                          []byte
	vers                         uint16
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	nextProtoNeg                 bool
	nextProtos                   []string
	ocspStapling                 bool
	scts                         [][]byte
	ticketSupported              bool
	secureRenegotiation          []byte
	secureRenegotiationSupported bool
	alpnProtocol                 string
}
522
523func (m *serverHelloMsg) equal(i interface{}) bool {
524	m1, ok := i.(*serverHelloMsg)
525	if !ok {
526		return false
527	}
528
529	if len(m.scts) != len(m1.scts) {
530		return false
531	}
532	for i, sct := range m.scts {
533		if !bytes.Equal(sct, m1.scts[i]) {
534			return false
535		}
536	}
537
538	return bytes.Equal(m.raw, m1.raw) &&
539		m.vers == m1.vers &&
540		bytes.Equal(m.random, m1.random) &&
541		bytes.Equal(m.sessionId, m1.sessionId) &&
542		m.cipherSuite == m1.cipherSuite &&
543		m.compressionMethod == m1.compressionMethod &&
544		m.nextProtoNeg == m1.nextProtoNeg &&
545		eqStrings(m.nextProtos, m1.nextProtos) &&
546		m.ocspStapling == m1.ocspStapling &&
547		m.ticketSupported == m1.ticketSupported &&
548		m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
549		bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
550		m.alpnProtocol == m1.alpnProtocol
551}
552
553func (m *serverHelloMsg) marshal() []byte {
554	if m.raw != nil {
555		return m.raw
556	}
557
558	length := 38 + len(m.sessionId)
559	numExtensions := 0
560	extensionsLength := 0
561
562	nextProtoLen := 0
563	if m.nextProtoNeg {
564		numExtensions++
565		for _, v := range m.nextProtos {
566			nextProtoLen += len(v)
567		}
568		nextProtoLen += len(m.nextProtos)
569		extensionsLength += nextProtoLen
570	}
571	if m.ocspStapling {
572		numExtensions++
573	}
574	if m.ticketSupported {
575		numExtensions++
576	}
577	if m.secureRenegotiationSupported {
578		extensionsLength += 1 + len(m.secureRenegotiation)
579		numExtensions++
580	}
581	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
582		if alpnLen >= 256 {
583			panic("invalid ALPN protocol")
584		}
585		extensionsLength += 2 + 1 + alpnLen
586		numExtensions++
587	}
588	sctLen := 0
589	if len(m.scts) > 0 {
590		for _, sct := range m.scts {
591			sctLen += len(sct) + 2
592		}
593		extensionsLength += 2 + sctLen
594		numExtensions++
595	}
596
597	if numExtensions > 0 {
598		extensionsLength += 4 * numExtensions
599		length += 2 + extensionsLength
600	}
601
602	x := make([]byte, 4+length)
603	x[0] = typeServerHello
604	x[1] = uint8(length >> 16)
605	x[2] = uint8(length >> 8)
606	x[3] = uint8(length)
607	x[4] = uint8(m.vers >> 8)
608	x[5] = uint8(m.vers)
609	copy(x[6:38], m.random)
610	x[38] = uint8(len(m.sessionId))
611	copy(x[39:39+len(m.sessionId)], m.sessionId)
612	z := x[39+len(m.sessionId):]
613	z[0] = uint8(m.cipherSuite >> 8)
614	z[1] = uint8(m.cipherSuite)
615	z[2] = m.compressionMethod
616
617	z = z[3:]
618	if numExtensions > 0 {
619		z[0] = byte(extensionsLength >> 8)
620		z[1] = byte(extensionsLength)
621		z = z[2:]
622	}
623	if m.nextProtoNeg {
624		z[0] = byte(extensionNextProtoNeg >> 8)
625		z[1] = byte(extensionNextProtoNeg & 0xff)
626		z[2] = byte(nextProtoLen >> 8)
627		z[3] = byte(nextProtoLen)
628		z = z[4:]
629
630		for _, v := range m.nextProtos {
631			l := len(v)
632			if l > 255 {
633				l = 255
634			}
635			z[0] = byte(l)
636			copy(z[1:], []byte(v[0:l]))
637			z = z[1+l:]
638		}
639	}
640	if m.ocspStapling {
641		z[0] = byte(extensionStatusRequest >> 8)
642		z[1] = byte(extensionStatusRequest)
643		z = z[4:]
644	}
645	if m.ticketSupported {
646		z[0] = byte(extensionSessionTicket >> 8)
647		z[1] = byte(extensionSessionTicket)
648		z = z[4:]
649	}
650	if m.secureRenegotiationSupported {
651		z[0] = byte(extensionRenegotiationInfo >> 8)
652		z[1] = byte(extensionRenegotiationInfo & 0xff)
653		z[2] = 0
654		z[3] = byte(len(m.secureRenegotiation) + 1)
655		z[4] = byte(len(m.secureRenegotiation))
656		z = z[5:]
657		copy(z, m.secureRenegotiation)
658		z = z[len(m.secureRenegotiation):]
659	}
660	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
661		z[0] = byte(extensionALPN >> 8)
662		z[1] = byte(extensionALPN & 0xff)
663		l := 2 + 1 + alpnLen
664		z[2] = byte(l >> 8)
665		z[3] = byte(l)
666		l -= 2
667		z[4] = byte(l >> 8)
668		z[5] = byte(l)
669		l -= 1
670		z[6] = byte(l)
671		copy(z[7:], []byte(m.alpnProtocol))
672		z = z[7+alpnLen:]
673	}
674	if sctLen > 0 {
675		z[0] = byte(extensionSCT >> 8)
676		z[1] = byte(extensionSCT)
677		l := sctLen + 2
678		z[2] = byte(l >> 8)
679		z[3] = byte(l)
680		z[4] = byte(sctLen >> 8)
681		z[5] = byte(sctLen)
682
683		z = z[6:]
684		for _, sct := range m.scts {
685			z[0] = byte(len(sct) >> 8)
686			z[1] = byte(len(sct))
687			copy(z[2:], sct)
688			z = z[len(sct)+2:]
689		}
690	}
691
692	m.raw = x
693
694	return x
695}
696
697func (m *serverHelloMsg) unmarshal(data []byte) bool {
698	if len(data) < 42 {
699		return false
700	}
701	m.raw = data
702	m.vers = uint16(data[4])<<8 | uint16(data[5])
703	m.random = data[6:38]
704	sessionIdLen := int(data[38])
705	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
706		return false
707	}
708	m.sessionId = data[39 : 39+sessionIdLen]
709	data = data[39+sessionIdLen:]
710	if len(data) < 3 {
711		return false
712	}
713	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
714	m.compressionMethod = data[2]
715	data = data[3:]
716
717	m.nextProtoNeg = false
718	m.nextProtos = nil
719	m.ocspStapling = false
720	m.scts = nil
721	m.ticketSupported = false
722	m.alpnProtocol = ""
723
724	if len(data) == 0 {
725		// ServerHello is optionally followed by extension data
726		return true
727	}
728	if len(data) < 2 {
729		return false
730	}
731
732	extensionsLength := int(data[0])<<8 | int(data[1])
733	data = data[2:]
734	if len(data) != extensionsLength {
735		return false
736	}
737
738	for len(data) != 0 {
739		if len(data) < 4 {
740			return false
741		}
742		extension := uint16(data[0])<<8 | uint16(data[1])
743		length := int(data[2])<<8 | int(data[3])
744		data = data[4:]
745		if len(data) < length {
746			return false
747		}
748
749		switch extension {
750		case extensionNextProtoNeg:
751			m.nextProtoNeg = true
752			d := data[:length]
753			for len(d) > 0 {
754				l := int(d[0])
755				d = d[1:]
756				if l == 0 || l > len(d) {
757					return false
758				}
759				m.nextProtos = append(m.nextProtos, string(d[:l]))
760				d = d[l:]
761			}
762		case extensionStatusRequest:
763			if length > 0 {
764				return false
765			}
766			m.ocspStapling = true
767		case extensionSessionTicket:
768			if length > 0 {
769				return false
770			}
771			m.ticketSupported = true
772		case extensionRenegotiationInfo:
773			if length == 0 {
774				return false
775			}
776			d := data[:length]
777			l := int(d[0])
778			d = d[1:]
779			if l != len(d) {
780				return false
781			}
782
783			m.secureRenegotiation = d
784			m.secureRenegotiationSupported = true
785		case extensionALPN:
786			d := data[:length]
787			if len(d) < 3 {
788				return false
789			}
790			l := int(d[0])<<8 | int(d[1])
791			if l != len(d)-2 {
792				return false
793			}
794			d = d[2:]
795			l = int(d[0])
796			if l != len(d)-1 {
797				return false
798			}
799			d = d[1:]
800			if len(d) == 0 {
801				// ALPN protocols must not be empty.
802				return false
803			}
804			m.alpnProtocol = string(d)
805		case extensionSCT:
806			d := data[:length]
807
808			if len(d) < 2 {
809				return false
810			}
811			l := int(d[0])<<8 | int(d[1])
812			d = d[2:]
813			if len(d) != l || l == 0 {
814				return false
815			}
816
817			m.scts = make([][]byte, 0, 3)
818			for len(d) != 0 {
819				if len(d) < 2 {
820					return false
821				}
822				sctLen := int(d[0])<<8 | int(d[1])
823				d = d[2:]
824				if sctLen == 0 || len(d) < sctLen {
825					return false
826				}
827				m.scts = append(m.scts, d[:sctLen])
828				d = d[sctLen:]
829			}
830		}
831		data = data[length:]
832	}
833
834	return true
835}
836
// certificateMsg is the decoded form of a Certificate handshake message.
// raw caches the wire encoding. certificates holds each certificate's bytes
// as carried on the wire, with a 24-bit length prefix per entry. These
// methods do not parse the entries. Presumably they are DER-encoded X.509
// certificates; confirm with callers.
type certificateMsg struct {
	raw          []byte
	certificates [][]byte
}
841
842func (m *certificateMsg) equal(i interface{}) bool {
843	m1, ok := i.(*certificateMsg)
844	if !ok {
845		return false
846	}
847
848	return bytes.Equal(m.raw, m1.raw) &&
849		eqByteSlices(m.certificates, m1.certificates)
850}
851
852func (m *certificateMsg) marshal() (x []byte) {
853	if m.raw != nil {
854		return m.raw
855	}
856
857	var i int
858	for _, slice := range m.certificates {
859		i += len(slice)
860	}
861
862	length := 3 + 3*len(m.certificates) + i
863	x = make([]byte, 4+length)
864	x[0] = typeCertificate
865	x[1] = uint8(length >> 16)
866	x[2] = uint8(length >> 8)
867	x[3] = uint8(length)
868
869	certificateOctets := length - 3
870	x[4] = uint8(certificateOctets >> 16)
871	x[5] = uint8(certificateOctets >> 8)
872	x[6] = uint8(certificateOctets)
873
874	y := x[7:]
875	for _, slice := range m.certificates {
876		y[0] = uint8(len(slice) >> 16)
877		y[1] = uint8(len(slice) >> 8)
878		y[2] = uint8(len(slice))
879		copy(y[3:], slice)
880		y = y[3+len(slice):]
881	}
882
883	m.raw = x
884	return
885}
886
887func (m *certificateMsg) unmarshal(data []byte) bool {
888	if len(data) < 7 {
889		return false
890	}
891
892	m.raw = data
893	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
894	if uint32(len(data)) != certsLen+7 {
895		return false
896	}
897
898	numCerts := 0
899	d := data[7:]
900	for certsLen > 0 {
901		if len(d) < 4 {
902			return false
903		}
904		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
905		if uint32(len(d)) < 3+certLen {
906			return false
907		}
908		d = d[3+certLen:]
909		certsLen -= 3 + certLen
910		numCerts++
911	}
912
913	m.certificates = make([][]byte, numCerts)
914	d = data[7:]
915	for i := 0; i < numCerts; i++ {
916		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
917		m.certificates[i] = d[3 : 3+certLen]
918		d = d[3+certLen:]
919	}
920
921	return true
922}
923
// serverKeyExchangeMsg is the decoded form of a ServerKeyExchange handshake
// message. raw caches the wire encoding. key holds the entire message body
// after the 4-byte header. Its contents are not interpreted by these
// methods.
type serverKeyExchangeMsg struct {
	raw []byte
	key []byte
}
928
929func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
930	m1, ok := i.(*serverKeyExchangeMsg)
931	if !ok {
932		return false
933	}
934
935	return bytes.Equal(m.raw, m1.raw) &&
936		bytes.Equal(m.key, m1.key)
937}
938
939func (m *serverKeyExchangeMsg) marshal() []byte {
940	if m.raw != nil {
941		return m.raw
942	}
943	length := len(m.key)
944	x := make([]byte, length+4)
945	x[0] = typeServerKeyExchange
946	x[1] = uint8(length >> 16)
947	x[2] = uint8(length >> 8)
948	x[3] = uint8(length)
949	copy(x[4:], m.key)
950
951	m.raw = x
952	return x
953}
954
955func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
956	m.raw = data
957	if len(data) < 4 {
958		return false
959	}
960	m.key = data[4:]
961	return true
962}
963
// certificateStatusMsg is the decoded form of a CertificateStatus handshake
// message. raw caches the wire encoding. statusType is the one-byte status
// type. response carries the response bytes only when statusType is
// statusTypeOCSP; for other types it is nil after unmarshal.
type certificateStatusMsg struct {
	raw        []byte
	statusType uint8
	response   []byte
}
969
970func (m *certificateStatusMsg) equal(i interface{}) bool {
971	m1, ok := i.(*certificateStatusMsg)
972	if !ok {
973		return false
974	}
975
976	return bytes.Equal(m.raw, m1.raw) &&
977		m.statusType == m1.statusType &&
978		bytes.Equal(m.response, m1.response)
979}
980
981func (m *certificateStatusMsg) marshal() []byte {
982	if m.raw != nil {
983		return m.raw
984	}
985
986	var x []byte
987	if m.statusType == statusTypeOCSP {
988		x = make([]byte, 4+4+len(m.response))
989		x[0] = typeCertificateStatus
990		l := len(m.response) + 4
991		x[1] = byte(l >> 16)
992		x[2] = byte(l >> 8)
993		x[3] = byte(l)
994		x[4] = statusTypeOCSP
995
996		l -= 4
997		x[5] = byte(l >> 16)
998		x[6] = byte(l >> 8)
999		x[7] = byte(l)
1000		copy(x[8:], m.response)
1001	} else {
1002		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
1003	}
1004
1005	m.raw = x
1006	return x
1007}
1008
1009func (m *certificateStatusMsg) unmarshal(data []byte) bool {
1010	m.raw = data
1011	if len(data) < 5 {
1012		return false
1013	}
1014	m.statusType = data[4]
1015
1016	m.response = nil
1017	if m.statusType == statusTypeOCSP {
1018		if len(data) < 8 {
1019			return false
1020		}
1021		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
1022		if uint32(len(data)) != 4+4+respLen {
1023			return false
1024		}
1025		m.response = data[8:]
1026	}
1027	return true
1028}
1029
// serverHelloDoneMsg is a ServerHelloDone handshake message, which has an
// empty body and therefore no fields.
type serverHelloDoneMsg struct{}
1031
1032func (m *serverHelloDoneMsg) equal(i interface{}) bool {
1033	_, ok := i.(*serverHelloDoneMsg)
1034	return ok
1035}
1036
1037func (m *serverHelloDoneMsg) marshal() []byte {
1038	x := make([]byte, 4)
1039	x[0] = typeServerHelloDone
1040	return x
1041}
1042
1043func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
1044	return len(data) == 4
1045}
1046
// clientKeyExchangeMsg is the decoded form of a ClientKeyExchange handshake
// message. raw caches the wire encoding. ciphertext is the whole message
// body after the 4-byte header and is not interpreted by these methods. The
// name suggests it carries the client's encrypted/encoded key material;
// confirm against the key agreement code.
type clientKeyExchangeMsg struct {
	raw        []byte
	ciphertext []byte
}
1051
1052func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
1053	m1, ok := i.(*clientKeyExchangeMsg)
1054	if !ok {
1055		return false
1056	}
1057
1058	return bytes.Equal(m.raw, m1.raw) &&
1059		bytes.Equal(m.ciphertext, m1.ciphertext)
1060}
1061
1062func (m *clientKeyExchangeMsg) marshal() []byte {
1063	if m.raw != nil {
1064		return m.raw
1065	}
1066	length := len(m.ciphertext)
1067	x := make([]byte, length+4)
1068	x[0] = typeClientKeyExchange
1069	x[1] = uint8(length >> 16)
1070	x[2] = uint8(length >> 8)
1071	x[3] = uint8(length)
1072	copy(x[4:], m.ciphertext)
1073
1074	m.raw = x
1075	return x
1076}
1077
1078func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
1079	m.raw = data
1080	if len(data) < 4 {
1081		return false
1082	}
1083	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1084	if l != len(data)-4 {
1085		return false
1086	}
1087	m.ciphertext = data[4:]
1088	return true
1089}
1090
// finishedMsg is the decoded form of a Finished handshake message. raw
// caches the wire encoding. verifyData is the whole message body after the
// 4-byte header.
type finishedMsg struct {
	raw        []byte
	verifyData []byte
}
1095
1096func (m *finishedMsg) equal(i interface{}) bool {
1097	m1, ok := i.(*finishedMsg)
1098	if !ok {
1099		return false
1100	}
1101
1102	return bytes.Equal(m.raw, m1.raw) &&
1103		bytes.Equal(m.verifyData, m1.verifyData)
1104}
1105
1106func (m *finishedMsg) marshal() (x []byte) {
1107	if m.raw != nil {
1108		return m.raw
1109	}
1110
1111	x = make([]byte, 4+len(m.verifyData))
1112	x[0] = typeFinished
1113	x[3] = byte(len(m.verifyData))
1114	copy(x[4:], m.verifyData)
1115	m.raw = x
1116	return
1117}
1118
1119func (m *finishedMsg) unmarshal(data []byte) bool {
1120	m.raw = data
1121	if len(data) < 4 {
1122		return false
1123	}
1124	m.verifyData = data[4:]
1125	return true
1126}
1127
// nextProtoMsg is the decoded form of an NPN NextProtocol handshake message.
// raw caches the wire encoding. proto is the selected protocol. marshal
// truncates it to 255 bytes and pads the body to a multiple of 32 bytes.
type nextProtoMsg struct {
	raw   []byte
	proto string
}
1132
1133func (m *nextProtoMsg) equal(i interface{}) bool {
1134	m1, ok := i.(*nextProtoMsg)
1135	if !ok {
1136		return false
1137	}
1138
1139	return bytes.Equal(m.raw, m1.raw) &&
1140		m.proto == m1.proto
1141}
1142
1143func (m *nextProtoMsg) marshal() []byte {
1144	if m.raw != nil {
1145		return m.raw
1146	}
1147	l := len(m.proto)
1148	if l > 255 {
1149		l = 255
1150	}
1151
1152	padding := 32 - (l+2)%32
1153	length := l + padding + 2
1154	x := make([]byte, length+4)
1155	x[0] = typeNextProtocol
1156	x[1] = uint8(length >> 16)
1157	x[2] = uint8(length >> 8)
1158	x[3] = uint8(length)
1159
1160	y := x[4:]
1161	y[0] = byte(l)
1162	copy(y[1:], []byte(m.proto[0:l]))
1163	y = y[1+l:]
1164	y[0] = byte(padding)
1165
1166	m.raw = x
1167
1168	return x
1169}
1170
1171func (m *nextProtoMsg) unmarshal(data []byte) bool {
1172	m.raw = data
1173
1174	if len(data) < 5 {
1175		return false
1176	}
1177	data = data[4:]
1178	protoLen := int(data[0])
1179	data = data[1:]
1180	if len(data) < protoLen {
1181		return false
1182	}
1183	m.proto = string(data[0:protoLen])
1184	data = data[protoLen:]
1185
1186	if len(data) < 1 {
1187		return false
1188	}
1189	paddingLen := int(data[0])
1190	data = data[1:]
1191	if len(data) != paddingLen {
1192		return false
1193	}
1194
1195	return true
1196}
1197
// certificateRequestMsg is the decoded form of a CertificateRequest
// handshake message. raw caches the wire encoding. certificateTypes,
// supportedSignatureAlgorithms and certificateAuthorities mirror the message
// body. Each certificate authority entry is carried with a 2-byte length
// prefix.
type certificateRequestMsg struct {
	raw []byte
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2.
	//
	// It is not read from the wire: marshal uses it to decide whether to
	// emit the list, and unmarshal reads it to decide whether to expect one,
	// so callers must set it before calling unmarshal.
	hasSignatureAndHash bool

	certificateTypes             []byte
	supportedSignatureAlgorithms []SignatureScheme
	certificateAuthorities       [][]byte
}
1209
1210func (m *certificateRequestMsg) equal(i interface{}) bool {
1211	m1, ok := i.(*certificateRequestMsg)
1212	if !ok {
1213		return false
1214	}
1215
1216	return bytes.Equal(m.raw, m1.raw) &&
1217		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
1218		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
1219		eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
1220}
1221
1222func (m *certificateRequestMsg) marshal() (x []byte) {
1223	if m.raw != nil {
1224		return m.raw
1225	}
1226
1227	// See http://tools.ietf.org/html/rfc4346#section-7.4.4
1228	length := 1 + len(m.certificateTypes) + 2
1229	casLength := 0
1230	for _, ca := range m.certificateAuthorities {
1231		casLength += 2 + len(ca)
1232	}
1233	length += casLength
1234
1235	if m.hasSignatureAndHash {
1236		length += 2 + 2*len(m.supportedSignatureAlgorithms)
1237	}
1238
1239	x = make([]byte, 4+length)
1240	x[0] = typeCertificateRequest
1241	x[1] = uint8(length >> 16)
1242	x[2] = uint8(length >> 8)
1243	x[3] = uint8(length)
1244
1245	x[4] = uint8(len(m.certificateTypes))
1246
1247	copy(x[5:], m.certificateTypes)
1248	y := x[5+len(m.certificateTypes):]
1249
1250	if m.hasSignatureAndHash {
1251		n := len(m.supportedSignatureAlgorithms) * 2
1252		y[0] = uint8(n >> 8)
1253		y[1] = uint8(n)
1254		y = y[2:]
1255		for _, sigAlgo := range m.supportedSignatureAlgorithms {
1256			y[0] = uint8(sigAlgo >> 8)
1257			y[1] = uint8(sigAlgo)
1258			y = y[2:]
1259		}
1260	}
1261
1262	y[0] = uint8(casLength >> 8)
1263	y[1] = uint8(casLength)
1264	y = y[2:]
1265	for _, ca := range m.certificateAuthorities {
1266		y[0] = uint8(len(ca) >> 8)
1267		y[1] = uint8(len(ca))
1268		y = y[2:]
1269		copy(y, ca)
1270		y = y[len(ca):]
1271	}
1272
1273	m.raw = x
1274	return
1275}
1276
1277func (m *certificateRequestMsg) unmarshal(data []byte) bool {
1278	m.raw = data
1279
1280	if len(data) < 5 {
1281		return false
1282	}
1283
1284	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1285	if uint32(len(data))-4 != length {
1286		return false
1287	}
1288
1289	numCertTypes := int(data[4])
1290	data = data[5:]
1291	if numCertTypes == 0 || len(data) <= numCertTypes {
1292		return false
1293	}
1294
1295	m.certificateTypes = make([]byte, numCertTypes)
1296	if copy(m.certificateTypes, data) != numCertTypes {
1297		return false
1298	}
1299
1300	data = data[numCertTypes:]
1301
1302	if m.hasSignatureAndHash {
1303		if len(data) < 2 {
1304			return false
1305		}
1306		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
1307		data = data[2:]
1308		if sigAndHashLen&1 != 0 {
1309			return false
1310		}
1311		if len(data) < int(sigAndHashLen) {
1312			return false
1313		}
1314		numSigAlgos := sigAndHashLen / 2
1315		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
1316		for i := range m.supportedSignatureAlgorithms {
1317			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
1318			data = data[2:]
1319		}
1320	}
1321
1322	if len(data) < 2 {
1323		return false
1324	}
1325	casLength := uint16(data[0])<<8 | uint16(data[1])
1326	data = data[2:]
1327	if len(data) < int(casLength) {
1328		return false
1329	}
1330	cas := make([]byte, casLength)
1331	copy(cas, data)
1332	data = data[casLength:]
1333
1334	m.certificateAuthorities = nil
1335	for len(cas) > 0 {
1336		if len(cas) < 2 {
1337			return false
1338		}
1339		caLen := uint16(cas[0])<<8 | uint16(cas[1])
1340		cas = cas[2:]
1341
1342		if len(cas) < int(caLen) {
1343			return false
1344		}
1345
1346		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
1347		cas = cas[caLen:]
1348	}
1349
1350	return len(data) == 0
1351}
1352
// certificateVerifyMsg is the CertificateVerify handshake message
// (see RFC 4346, section 7.4.8).
type certificateVerifyMsg struct {
	// raw caches the wire encoding: marshal returns it when non-nil and
	// unmarshal stores its input here.
	raw                 []byte
	// hasSignatureAndHash selects whether a 2-byte signatureAlgorithm precedes
	// the signature on the wire. unmarshal reads it rather than setting it.
	// Presumably true for TLS 1.2 and later, as with certificateRequestMsg.
	hasSignatureAndHash bool
	// signatureAlgorithm is only encoded/decoded when hasSignatureAndHash is set.
	signatureAlgorithm  SignatureScheme
	// signature holds the signature bytes, carried with a 16-bit length prefix.
	signature           []byte
}
1359
1360func (m *certificateVerifyMsg) equal(i interface{}) bool {
1361	m1, ok := i.(*certificateVerifyMsg)
1362	if !ok {
1363		return false
1364	}
1365
1366	return bytes.Equal(m.raw, m1.raw) &&
1367		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
1368		m.signatureAlgorithm == m1.signatureAlgorithm &&
1369		bytes.Equal(m.signature, m1.signature)
1370}
1371
1372func (m *certificateVerifyMsg) marshal() (x []byte) {
1373	if m.raw != nil {
1374		return m.raw
1375	}
1376
1377	// See http://tools.ietf.org/html/rfc4346#section-7.4.8
1378	siglength := len(m.signature)
1379	length := 2 + siglength
1380	if m.hasSignatureAndHash {
1381		length += 2
1382	}
1383	x = make([]byte, 4+length)
1384	x[0] = typeCertificateVerify
1385	x[1] = uint8(length >> 16)
1386	x[2] = uint8(length >> 8)
1387	x[3] = uint8(length)
1388	y := x[4:]
1389	if m.hasSignatureAndHash {
1390		y[0] = uint8(m.signatureAlgorithm >> 8)
1391		y[1] = uint8(m.signatureAlgorithm)
1392		y = y[2:]
1393	}
1394	y[0] = uint8(siglength >> 8)
1395	y[1] = uint8(siglength)
1396	copy(y[2:], m.signature)
1397
1398	m.raw = x
1399
1400	return
1401}
1402
1403func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
1404	m.raw = data
1405
1406	if len(data) < 6 {
1407		return false
1408	}
1409
1410	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1411	if uint32(len(data))-4 != length {
1412		return false
1413	}
1414
1415	data = data[4:]
1416	if m.hasSignatureAndHash {
1417		m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
1418		data = data[2:]
1419	}
1420
1421	if len(data) < 2 {
1422		return false
1423	}
1424	siglength := int(data[0])<<8 + int(data[1])
1425	data = data[2:]
1426	if len(data) != siglength {
1427		return false
1428	}
1429
1430	m.signature = data
1431
1432	return true
1433}
1434
// newSessionTicketMsg is the NewSessionTicket handshake message
// (see RFC 5077, section 3.3).
type newSessionTicketMsg struct {
	// raw caches the wire encoding: marshal returns it when non-nil and
	// unmarshal stores its input here.
	raw    []byte
	// ticket holds the ticket bytes, carried with a 16-bit length prefix. This
	// code does not interpret its contents.
	ticket []byte
}
1439
1440func (m *newSessionTicketMsg) equal(i interface{}) bool {
1441	m1, ok := i.(*newSessionTicketMsg)
1442	if !ok {
1443		return false
1444	}
1445
1446	return bytes.Equal(m.raw, m1.raw) &&
1447		bytes.Equal(m.ticket, m1.ticket)
1448}
1449
1450func (m *newSessionTicketMsg) marshal() (x []byte) {
1451	if m.raw != nil {
1452		return m.raw
1453	}
1454
1455	// See http://tools.ietf.org/html/rfc5077#section-3.3
1456	ticketLen := len(m.ticket)
1457	length := 2 + 4 + ticketLen
1458	x = make([]byte, 4+length)
1459	x[0] = typeNewSessionTicket
1460	x[1] = uint8(length >> 16)
1461	x[2] = uint8(length >> 8)
1462	x[3] = uint8(length)
1463	x[8] = uint8(ticketLen >> 8)
1464	x[9] = uint8(ticketLen)
1465	copy(x[10:], m.ticket)
1466
1467	m.raw = x
1468
1469	return
1470}
1471
1472func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
1473	m.raw = data
1474
1475	if len(data) < 10 {
1476		return false
1477	}
1478
1479	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1480	if uint32(len(data))-4 != length {
1481		return false
1482	}
1483
1484	ticketLen := int(data[8])<<8 + int(data[9])
1485	if len(data)-10 != ticketLen {
1486		return false
1487	}
1488
1489	m.ticket = data[10:]
1490
1491	return true
1492}
1493
// helloRequestMsg is the HelloRequest handshake message. It has no fields;
// its encoding is just a 4-byte handshake header with a zero length.
type helloRequestMsg struct {
}
1496
1497func (*helloRequestMsg) marshal() []byte {
1498	return []byte{typeHelloRequest, 0, 0, 0}
1499}
1500
1501func (*helloRequestMsg) unmarshal(data []byte) bool {
1502	return len(data) == 4
1503}
1504
// eqUint16s reports whether x and y hold the same values in the same order.
// A nil slice and an empty slice compare equal.
func eqUint16s(x, y []uint16) bool {
	same := len(x) == len(y)
	for i := 0; same && i < len(x); i++ {
		same = x[i] == y[i]
	}
	return same
}
1516
1517func eqCurveIDs(x, y []CurveID) bool {
1518	if len(x) != len(y) {
1519		return false
1520	}
1521	for i, v := range x {
1522		if y[i] != v {
1523			return false
1524		}
1525	}
1526	return true
1527}
1528
// eqStrings reports whether x and y hold the same strings in the same order.
// A nil slice and an empty slice compare equal.
func eqStrings(x, y []string) bool {
	same := len(x) == len(y)
	for i := 0; same && i < len(x); i++ {
		same = x[i] == y[i]
	}
	return same
}
1540
// eqByteSlices reports whether x and y contain pairwise-equal byte slices in
// the same order, comparing elements with bytes.Equal (so a nil element equals
// an empty one).
func eqByteSlices(x, y [][]byte) bool {
	same := len(x) == len(y)
	for i := 0; same && i < len(x); i++ {
		same = bytes.Equal(x[i], y[i])
	}
	return same
}
1552
1553func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
1554	if len(x) != len(y) {
1555		return false
1556	}
1557	for i, v := range x {
1558		if v != y[i] {
1559			return false
1560		}
1561	}
1562	return true
1563}
1564