1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import "bytes"
8
// clientHelloMsg is the in-memory form of a ClientHello handshake message.
//
// raw caches the wire encoding: marshal returns it unchanged when non-nil,
// and unmarshal stores the slice it parsed.
//
// The fixed ClientHello fields are:
//   - vers: the protocol version.
//   - random: 32 bytes on the wire.
//   - sessionId: at most 32 bytes when parsed.
//   - cipherSuites and compressionMethods.
//
// The remaining fields mirror hello extensions:
//   - nextProtoNeg: an empty Next Protocol Negotiation extension.
//   - serverName: the host_name entry of the server_name extension.
//   - ocspStapling: a status_request extension asking for OCSP.
//   - scts: an empty signed certificate timestamp extension.
//   - supportedCurves / supportedPoints: the elliptic curve and point
//     format lists.
//   - ticketSupported / sessionTicket: the session_ticket extension and its
//     (possibly empty) ticket.
//   - signatureAndHashes: the signature_algorithms list.
//   - secureRenegotiation: an empty renegotiation_info extension or, when
//     parsing, the renegotiation SCSV cipher suite.
//   - alpnProtocols: the ALPN protocol name list.
type clientHelloMsg struct {
	raw                 []byte
	vers                uint16
	random              []byte
	sessionId           []byte
	cipherSuites        []uint16
	compressionMethods  []uint8
	nextProtoNeg        bool
	serverName          string
	ocspStapling        bool
	scts                bool
	supportedCurves     []CurveID
	supportedPoints     []uint8
	ticketSupported     bool
	sessionTicket       []uint8
	signatureAndHashes  []signatureAndHash
	secureRenegotiation bool
	alpnProtocols       []string
}
28
29func (m *clientHelloMsg) equal(i interface{}) bool {
30	m1, ok := i.(*clientHelloMsg)
31	if !ok {
32		return false
33	}
34
35	return bytes.Equal(m.raw, m1.raw) &&
36		m.vers == m1.vers &&
37		bytes.Equal(m.random, m1.random) &&
38		bytes.Equal(m.sessionId, m1.sessionId) &&
39		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
40		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
41		m.nextProtoNeg == m1.nextProtoNeg &&
42		m.serverName == m1.serverName &&
43		m.ocspStapling == m1.ocspStapling &&
44		m.scts == m1.scts &&
45		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
46		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
47		m.ticketSupported == m1.ticketSupported &&
48		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
49		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
50		m.secureRenegotiation == m1.secureRenegotiation &&
51		eqStrings(m.alpnProtocols, m1.alpnProtocols)
52}
53
// marshal encodes m as a ClientHello handshake message and returns it.
//
// The message consists of:
//   - a 4-byte header (message type and 24-bit body length),
//   - the version and the 32-byte random,
//   - the length-prefixed session ID, cipher suite vector and compression
//     method vector,
//   - the extensions block, only when at least one extension is enabled.
//
// If m.raw is already set it is returned unchanged. Otherwise the new
// encoding is stored in m.raw before it is returned.
//
// Sizes are computed in a first pass, and the writes that follow must stay
// in sync with them. marshal panics if an ALPN protocol name is empty or
// longer than 255 bytes. Other lengths are not range-checked: the byte()
// conversions keep only the low bits, so oversized values would be encoded
// incorrectly.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// version (2) + random (32) + session ID (1+n) + cipher suites (2+2n) +
	// compression methods (1+n).
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	numExtensions := 0
	// Payload sizes only; the 4-byte type/length header of each extension is
	// added once all extensions have been counted.
	extensionsLength := 0
	if m.nextProtoNeg {
		numExtensions++
	}
	if m.ocspStapling {
		// status type (1) + two empty uint16-prefixed vectors (2 + 2).
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// list length (2) + name type (1) + name length (2) + name.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		// The ticket bytes are the whole payload; an empty ticket gives an
		// empty extension.
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.signatureAndHashes) > 0 {
		extensionsLength += 2 + 2*len(m.signatureAndHashes)
		numExtensions++
	}
	if m.secureRenegotiation {
		// A single zero byte: an empty renegotiated_connection vector.
		extensionsLength += 1
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		// list length (2) + one length byte and the bytes for each name.
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		numExtensions++
	}
	if numExtensions > 0 {
		extensionsLength += 4 * numExtensions
		// +2 for the extensions block's own length prefix.
		length += 2 + extensionsLength
	}

	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// Byte length of the suite vector is 2n: (2n)>>8 == n>>7 and the low
	// byte of 2n is byte(n<<1).
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	// Each extension below is written in the same order it was sized above;
	// bytes not explicitly set are left as zero by make.
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		// z[2] (name_type) stays 0, i.e. host_name.
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = byte(pointFormat)
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.signatureAndHashes) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.signatureAndHashes)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAndHash := range m.signatureAndHashes {
			z[0] = sigAndHash.hash
			z[1] = sigAndHash.signature
			z = z[2:]
		}
	}
	if m.secureRenegotiation {
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = 1
		// z[4] stays 0: the empty renegotiated_connection vector.
		z = z[5:]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// lengths aliases the extension length (2 bytes) and the protocol
		// list length (2 bytes); both are filled in after the names.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// https://tools.ietf.org/html/rfc6962#section-3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}

	m.raw = x

	return x
}
291
292func (m *clientHelloMsg) unmarshal(data []byte) bool {
293	if len(data) < 42 {
294		return false
295	}
296	m.raw = data
297	m.vers = uint16(data[4])<<8 | uint16(data[5])
298	m.random = data[6:38]
299	sessionIdLen := int(data[38])
300	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
301		return false
302	}
303	m.sessionId = data[39 : 39+sessionIdLen]
304	data = data[39+sessionIdLen:]
305	if len(data) < 2 {
306		return false
307	}
308	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
309	// they are uint16s, the number must be even.
310	cipherSuiteLen := int(data[0])<<8 | int(data[1])
311	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
312		return false
313	}
314	numCipherSuites := cipherSuiteLen / 2
315	m.cipherSuites = make([]uint16, numCipherSuites)
316	for i := 0; i < numCipherSuites; i++ {
317		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
318		if m.cipherSuites[i] == scsvRenegotiation {
319			m.secureRenegotiation = true
320		}
321	}
322	data = data[2+cipherSuiteLen:]
323	if len(data) < 1 {
324		return false
325	}
326	compressionMethodsLen := int(data[0])
327	if len(data) < 1+compressionMethodsLen {
328		return false
329	}
330	m.compressionMethods = data[1 : 1+compressionMethodsLen]
331
332	data = data[1+compressionMethodsLen:]
333
334	m.nextProtoNeg = false
335	m.serverName = ""
336	m.ocspStapling = false
337	m.ticketSupported = false
338	m.sessionTicket = nil
339	m.signatureAndHashes = nil
340	m.alpnProtocols = nil
341	m.scts = false
342
343	if len(data) == 0 {
344		// ClientHello is optionally followed by extension data
345		return true
346	}
347	if len(data) < 2 {
348		return false
349	}
350
351	extensionsLength := int(data[0])<<8 | int(data[1])
352	data = data[2:]
353	if extensionsLength != len(data) {
354		return false
355	}
356
357	for len(data) != 0 {
358		if len(data) < 4 {
359			return false
360		}
361		extension := uint16(data[0])<<8 | uint16(data[1])
362		length := int(data[2])<<8 | int(data[3])
363		data = data[4:]
364		if len(data) < length {
365			return false
366		}
367
368		switch extension {
369		case extensionServerName:
370			d := data[:length]
371			if len(d) < 2 {
372				return false
373			}
374			namesLen := int(d[0])<<8 | int(d[1])
375			d = d[2:]
376			if len(d) != namesLen {
377				return false
378			}
379			for len(d) > 0 {
380				if len(d) < 3 {
381					return false
382				}
383				nameType := d[0]
384				nameLen := int(d[1])<<8 | int(d[2])
385				d = d[3:]
386				if len(d) < nameLen {
387					return false
388				}
389				if nameType == 0 {
390					m.serverName = string(d[:nameLen])
391					break
392				}
393				d = d[nameLen:]
394			}
395		case extensionNextProtoNeg:
396			if length > 0 {
397				return false
398			}
399			m.nextProtoNeg = true
400		case extensionStatusRequest:
401			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
402		case extensionSupportedCurves:
403			// http://tools.ietf.org/html/rfc4492#section-5.5.1
404			if length < 2 {
405				return false
406			}
407			l := int(data[0])<<8 | int(data[1])
408			if l%2 == 1 || length != l+2 {
409				return false
410			}
411			numCurves := l / 2
412			m.supportedCurves = make([]CurveID, numCurves)
413			d := data[2:]
414			for i := 0; i < numCurves; i++ {
415				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
416				d = d[2:]
417			}
418		case extensionSupportedPoints:
419			// http://tools.ietf.org/html/rfc4492#section-5.5.2
420			if length < 1 {
421				return false
422			}
423			l := int(data[0])
424			if length != l+1 {
425				return false
426			}
427			m.supportedPoints = make([]uint8, l)
428			copy(m.supportedPoints, data[1:])
429		case extensionSessionTicket:
430			// http://tools.ietf.org/html/rfc5077#section-3.2
431			m.ticketSupported = true
432			m.sessionTicket = data[:length]
433		case extensionSignatureAlgorithms:
434			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
435			if length < 2 || length&1 != 0 {
436				return false
437			}
438			l := int(data[0])<<8 | int(data[1])
439			if l != length-2 {
440				return false
441			}
442			n := l / 2
443			d := data[2:]
444			m.signatureAndHashes = make([]signatureAndHash, n)
445			for i := range m.signatureAndHashes {
446				m.signatureAndHashes[i].hash = d[0]
447				m.signatureAndHashes[i].signature = d[1]
448				d = d[2:]
449			}
450		case extensionRenegotiationInfo:
451			if length != 1 || data[0] != 0 {
452				return false
453			}
454			m.secureRenegotiation = true
455		case extensionALPN:
456			if length < 2 {
457				return false
458			}
459			l := int(data[0])<<8 | int(data[1])
460			if l != length-2 {
461				return false
462			}
463			d := data[2:length]
464			for len(d) != 0 {
465				stringLen := int(d[0])
466				d = d[1:]
467				if stringLen == 0 || stringLen > len(d) {
468					return false
469				}
470				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
471				d = d[stringLen:]
472			}
473		case extensionSCT:
474			m.scts = true
475			if length != 0 {
476				return false
477			}
478		}
479		data = data[length:]
480	}
481
482	return true
483}
484
// serverHelloMsg is the in-memory form of a ServerHello handshake message.
//
// raw caches the wire encoding: marshal returns it unchanged when non-nil,
// and unmarshal stores the slice it parsed.
//
// vers, random, sessionId, cipherSuite and compressionMethod are the fixed
// ServerHello fields. The remaining fields mirror extensions:
//   - nextProtoNeg / nextProtos: Next Protocol Negotiation and the
//     advertised protocol names.
//   - ocspStapling: an empty status_request extension.
//   - scts: the entries of the signed certificate timestamp list.
//   - ticketSupported: an empty session_ticket extension.
//   - secureRenegotiation: renegotiation_info with a single zero byte.
//   - alpnProtocol: the one protocol selected via ALPN.
type serverHelloMsg struct {
	raw                 []byte
	vers                uint16
	random              []byte
	sessionId           []byte
	cipherSuite         uint16
	compressionMethod   uint8
	nextProtoNeg        bool
	nextProtos          []string
	ocspStapling        bool
	scts                [][]byte
	ticketSupported     bool
	secureRenegotiation bool
	alpnProtocol        string
}
500
501func (m *serverHelloMsg) equal(i interface{}) bool {
502	m1, ok := i.(*serverHelloMsg)
503	if !ok {
504		return false
505	}
506
507	if len(m.scts) != len(m1.scts) {
508		return false
509	}
510	for i, sct := range m.scts {
511		if !bytes.Equal(sct, m1.scts[i]) {
512			return false
513		}
514	}
515
516	return bytes.Equal(m.raw, m1.raw) &&
517		m.vers == m1.vers &&
518		bytes.Equal(m.random, m1.random) &&
519		bytes.Equal(m.sessionId, m1.sessionId) &&
520		m.cipherSuite == m1.cipherSuite &&
521		m.compressionMethod == m1.compressionMethod &&
522		m.nextProtoNeg == m1.nextProtoNeg &&
523		eqStrings(m.nextProtos, m1.nextProtos) &&
524		m.ocspStapling == m1.ocspStapling &&
525		m.ticketSupported == m1.ticketSupported &&
526		m.secureRenegotiation == m1.secureRenegotiation &&
527		m.alpnProtocol == m1.alpnProtocol
528}
529
530func (m *serverHelloMsg) marshal() []byte {
531	if m.raw != nil {
532		return m.raw
533	}
534
535	length := 38 + len(m.sessionId)
536	numExtensions := 0
537	extensionsLength := 0
538
539	nextProtoLen := 0
540	if m.nextProtoNeg {
541		numExtensions++
542		for _, v := range m.nextProtos {
543			nextProtoLen += len(v)
544		}
545		nextProtoLen += len(m.nextProtos)
546		extensionsLength += nextProtoLen
547	}
548	if m.ocspStapling {
549		numExtensions++
550	}
551	if m.ticketSupported {
552		numExtensions++
553	}
554	if m.secureRenegotiation {
555		extensionsLength += 1
556		numExtensions++
557	}
558	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
559		if alpnLen >= 256 {
560			panic("invalid ALPN protocol")
561		}
562		extensionsLength += 2 + 1 + alpnLen
563		numExtensions++
564	}
565	sctLen := 0
566	if len(m.scts) > 0 {
567		for _, sct := range m.scts {
568			sctLen += len(sct) + 2
569		}
570		extensionsLength += 2 + sctLen
571		numExtensions++
572	}
573
574	if numExtensions > 0 {
575		extensionsLength += 4 * numExtensions
576		length += 2 + extensionsLength
577	}
578
579	x := make([]byte, 4+length)
580	x[0] = typeServerHello
581	x[1] = uint8(length >> 16)
582	x[2] = uint8(length >> 8)
583	x[3] = uint8(length)
584	x[4] = uint8(m.vers >> 8)
585	x[5] = uint8(m.vers)
586	copy(x[6:38], m.random)
587	x[38] = uint8(len(m.sessionId))
588	copy(x[39:39+len(m.sessionId)], m.sessionId)
589	z := x[39+len(m.sessionId):]
590	z[0] = uint8(m.cipherSuite >> 8)
591	z[1] = uint8(m.cipherSuite)
592	z[2] = uint8(m.compressionMethod)
593
594	z = z[3:]
595	if numExtensions > 0 {
596		z[0] = byte(extensionsLength >> 8)
597		z[1] = byte(extensionsLength)
598		z = z[2:]
599	}
600	if m.nextProtoNeg {
601		z[0] = byte(extensionNextProtoNeg >> 8)
602		z[1] = byte(extensionNextProtoNeg & 0xff)
603		z[2] = byte(nextProtoLen >> 8)
604		z[3] = byte(nextProtoLen)
605		z = z[4:]
606
607		for _, v := range m.nextProtos {
608			l := len(v)
609			if l > 255 {
610				l = 255
611			}
612			z[0] = byte(l)
613			copy(z[1:], []byte(v[0:l]))
614			z = z[1+l:]
615		}
616	}
617	if m.ocspStapling {
618		z[0] = byte(extensionStatusRequest >> 8)
619		z[1] = byte(extensionStatusRequest)
620		z = z[4:]
621	}
622	if m.ticketSupported {
623		z[0] = byte(extensionSessionTicket >> 8)
624		z[1] = byte(extensionSessionTicket)
625		z = z[4:]
626	}
627	if m.secureRenegotiation {
628		z[0] = byte(extensionRenegotiationInfo >> 8)
629		z[1] = byte(extensionRenegotiationInfo & 0xff)
630		z[2] = 0
631		z[3] = 1
632		z = z[5:]
633	}
634	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
635		z[0] = byte(extensionALPN >> 8)
636		z[1] = byte(extensionALPN & 0xff)
637		l := 2 + 1 + alpnLen
638		z[2] = byte(l >> 8)
639		z[3] = byte(l)
640		l -= 2
641		z[4] = byte(l >> 8)
642		z[5] = byte(l)
643		l -= 1
644		z[6] = byte(l)
645		copy(z[7:], []byte(m.alpnProtocol))
646		z = z[7+alpnLen:]
647	}
648	if sctLen > 0 {
649		z[0] = byte(extensionSCT >> 8)
650		z[1] = byte(extensionSCT)
651		l := sctLen + 2
652		z[2] = byte(l >> 8)
653		z[3] = byte(l)
654		z[4] = byte(sctLen >> 8)
655		z[5] = byte(sctLen)
656
657		z = z[6:]
658		for _, sct := range m.scts {
659			z[0] = byte(len(sct) >> 8)
660			z[1] = byte(len(sct))
661			copy(z[2:], sct)
662			z = z[len(sct)+2:]
663		}
664	}
665
666	m.raw = x
667
668	return x
669}
670
671func (m *serverHelloMsg) unmarshal(data []byte) bool {
672	if len(data) < 42 {
673		return false
674	}
675	m.raw = data
676	m.vers = uint16(data[4])<<8 | uint16(data[5])
677	m.random = data[6:38]
678	sessionIdLen := int(data[38])
679	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
680		return false
681	}
682	m.sessionId = data[39 : 39+sessionIdLen]
683	data = data[39+sessionIdLen:]
684	if len(data) < 3 {
685		return false
686	}
687	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
688	m.compressionMethod = data[2]
689	data = data[3:]
690
691	m.nextProtoNeg = false
692	m.nextProtos = nil
693	m.ocspStapling = false
694	m.scts = nil
695	m.ticketSupported = false
696	m.alpnProtocol = ""
697
698	if len(data) == 0 {
699		// ServerHello is optionally followed by extension data
700		return true
701	}
702	if len(data) < 2 {
703		return false
704	}
705
706	extensionsLength := int(data[0])<<8 | int(data[1])
707	data = data[2:]
708	if len(data) != extensionsLength {
709		return false
710	}
711
712	for len(data) != 0 {
713		if len(data) < 4 {
714			return false
715		}
716		extension := uint16(data[0])<<8 | uint16(data[1])
717		length := int(data[2])<<8 | int(data[3])
718		data = data[4:]
719		if len(data) < length {
720			return false
721		}
722
723		switch extension {
724		case extensionNextProtoNeg:
725			m.nextProtoNeg = true
726			d := data[:length]
727			for len(d) > 0 {
728				l := int(d[0])
729				d = d[1:]
730				if l == 0 || l > len(d) {
731					return false
732				}
733				m.nextProtos = append(m.nextProtos, string(d[:l]))
734				d = d[l:]
735			}
736		case extensionStatusRequest:
737			if length > 0 {
738				return false
739			}
740			m.ocspStapling = true
741		case extensionSessionTicket:
742			if length > 0 {
743				return false
744			}
745			m.ticketSupported = true
746		case extensionRenegotiationInfo:
747			if length != 1 || data[0] != 0 {
748				return false
749			}
750			m.secureRenegotiation = true
751		case extensionALPN:
752			d := data[:length]
753			if len(d) < 3 {
754				return false
755			}
756			l := int(d[0])<<8 | int(d[1])
757			if l != len(d)-2 {
758				return false
759			}
760			d = d[2:]
761			l = int(d[0])
762			if l != len(d)-1 {
763				return false
764			}
765			d = d[1:]
766			m.alpnProtocol = string(d)
767		case extensionSCT:
768			d := data[:length]
769
770			if len(d) < 2 {
771				return false
772			}
773			l := int(d[0])<<8 | int(d[1])
774			d = d[2:]
775			if len(d) != l {
776				return false
777			}
778			if l == 0 {
779				continue
780			}
781
782			m.scts = make([][]byte, 0, 3)
783			for len(d) != 0 {
784				if len(d) < 2 {
785					return false
786				}
787				sctLen := int(d[0])<<8 | int(d[1])
788				d = d[2:]
789				if len(d) < sctLen {
790					return false
791				}
792				m.scts = append(m.scts, d[:sctLen])
793				d = d[sctLen:]
794			}
795		}
796		data = data[length:]
797	}
798
799	return true
800}
801
// certificateMsg is a Certificate handshake message.
//
// certificates holds the entries of the certificate list in wire order.
// Each entry is an opaque byte string; it is presumably a DER-encoded
// certificate, but that is not checked here.
//
// raw caches the wire encoding: marshal returns it when non-nil, and
// unmarshal stores the slice it parsed.
type certificateMsg struct {
	raw          []byte
	certificates [][]byte
}
806
807func (m *certificateMsg) equal(i interface{}) bool {
808	m1, ok := i.(*certificateMsg)
809	if !ok {
810		return false
811	}
812
813	return bytes.Equal(m.raw, m1.raw) &&
814		eqByteSlices(m.certificates, m1.certificates)
815}
816
817func (m *certificateMsg) marshal() (x []byte) {
818	if m.raw != nil {
819		return m.raw
820	}
821
822	var i int
823	for _, slice := range m.certificates {
824		i += len(slice)
825	}
826
827	length := 3 + 3*len(m.certificates) + i
828	x = make([]byte, 4+length)
829	x[0] = typeCertificate
830	x[1] = uint8(length >> 16)
831	x[2] = uint8(length >> 8)
832	x[3] = uint8(length)
833
834	certificateOctets := length - 3
835	x[4] = uint8(certificateOctets >> 16)
836	x[5] = uint8(certificateOctets >> 8)
837	x[6] = uint8(certificateOctets)
838
839	y := x[7:]
840	for _, slice := range m.certificates {
841		y[0] = uint8(len(slice) >> 16)
842		y[1] = uint8(len(slice) >> 8)
843		y[2] = uint8(len(slice))
844		copy(y[3:], slice)
845		y = y[3+len(slice):]
846	}
847
848	m.raw = x
849	return
850}
851
852func (m *certificateMsg) unmarshal(data []byte) bool {
853	if len(data) < 7 {
854		return false
855	}
856
857	m.raw = data
858	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
859	if uint32(len(data)) != certsLen+7 {
860		return false
861	}
862
863	numCerts := 0
864	d := data[7:]
865	for certsLen > 0 {
866		if len(d) < 4 {
867			return false
868		}
869		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
870		if uint32(len(d)) < 3+certLen {
871			return false
872		}
873		d = d[3+certLen:]
874		certsLen -= 3 + certLen
875		numCerts++
876	}
877
878	m.certificates = make([][]byte, numCerts)
879	d = data[7:]
880	for i := 0; i < numCerts; i++ {
881		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
882		m.certificates[i] = d[3 : 3+certLen]
883		d = d[3+certLen:]
884	}
885
886	return true
887}
888
// serverKeyExchangeMsg is a ServerKeyExchange handshake message.
//
// key holds the whole message body after the 4-byte handshake header. This
// type does not interpret it. raw caches the wire encoding.
type serverKeyExchangeMsg struct {
	raw []byte
	key []byte
}
893
894func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
895	m1, ok := i.(*serverKeyExchangeMsg)
896	if !ok {
897		return false
898	}
899
900	return bytes.Equal(m.raw, m1.raw) &&
901		bytes.Equal(m.key, m1.key)
902}
903
904func (m *serverKeyExchangeMsg) marshal() []byte {
905	if m.raw != nil {
906		return m.raw
907	}
908	length := len(m.key)
909	x := make([]byte, length+4)
910	x[0] = typeServerKeyExchange
911	x[1] = uint8(length >> 16)
912	x[2] = uint8(length >> 8)
913	x[3] = uint8(length)
914	copy(x[4:], m.key)
915
916	m.raw = x
917	return x
918}
919
920func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
921	m.raw = data
922	if len(data) < 4 {
923		return false
924	}
925	m.key = data[4:]
926	return true
927}
928
// certificateStatusMsg is a CertificateStatus handshake message.
//
// statusType is the status type byte. response holds the OCSP response body
// and is only populated when statusType is statusTypeOCSP. raw caches the
// wire encoding.
type certificateStatusMsg struct {
	raw        []byte
	statusType uint8
	response   []byte
}
934
935func (m *certificateStatusMsg) equal(i interface{}) bool {
936	m1, ok := i.(*certificateStatusMsg)
937	if !ok {
938		return false
939	}
940
941	return bytes.Equal(m.raw, m1.raw) &&
942		m.statusType == m1.statusType &&
943		bytes.Equal(m.response, m1.response)
944}
945
946func (m *certificateStatusMsg) marshal() []byte {
947	if m.raw != nil {
948		return m.raw
949	}
950
951	var x []byte
952	if m.statusType == statusTypeOCSP {
953		x = make([]byte, 4+4+len(m.response))
954		x[0] = typeCertificateStatus
955		l := len(m.response) + 4
956		x[1] = byte(l >> 16)
957		x[2] = byte(l >> 8)
958		x[3] = byte(l)
959		x[4] = statusTypeOCSP
960
961		l -= 4
962		x[5] = byte(l >> 16)
963		x[6] = byte(l >> 8)
964		x[7] = byte(l)
965		copy(x[8:], m.response)
966	} else {
967		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
968	}
969
970	m.raw = x
971	return x
972}
973
974func (m *certificateStatusMsg) unmarshal(data []byte) bool {
975	m.raw = data
976	if len(data) < 5 {
977		return false
978	}
979	m.statusType = data[4]
980
981	m.response = nil
982	if m.statusType == statusTypeOCSP {
983		if len(data) < 8 {
984			return false
985		}
986		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
987		if uint32(len(data)) != 4+4+respLen {
988			return false
989		}
990		m.response = data[8:]
991	}
992	return true
993}
994
// serverHelloDoneMsg is a ServerHelloDone handshake message. It has an empty
// body, so the type carries no fields.
type serverHelloDoneMsg struct{}
996
997func (m *serverHelloDoneMsg) equal(i interface{}) bool {
998	_, ok := i.(*serverHelloDoneMsg)
999	return ok
1000}
1001
1002func (m *serverHelloDoneMsg) marshal() []byte {
1003	x := make([]byte, 4)
1004	x[0] = typeServerHelloDone
1005	return x
1006}
1007
1008func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
1009	return len(data) == 4
1010}
1011
// clientKeyExchangeMsg is a ClientKeyExchange handshake message.
//
// ciphertext holds the message body after the 4-byte handshake header,
// uninterpreted here. Its contents presumably depend on the negotiated key
// exchange; confirm with the callers. raw caches the wire encoding.
type clientKeyExchangeMsg struct {
	raw        []byte
	ciphertext []byte
}
1016
1017func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
1018	m1, ok := i.(*clientKeyExchangeMsg)
1019	if !ok {
1020		return false
1021	}
1022
1023	return bytes.Equal(m.raw, m1.raw) &&
1024		bytes.Equal(m.ciphertext, m1.ciphertext)
1025}
1026
1027func (m *clientKeyExchangeMsg) marshal() []byte {
1028	if m.raw != nil {
1029		return m.raw
1030	}
1031	length := len(m.ciphertext)
1032	x := make([]byte, length+4)
1033	x[0] = typeClientKeyExchange
1034	x[1] = uint8(length >> 16)
1035	x[2] = uint8(length >> 8)
1036	x[3] = uint8(length)
1037	copy(x[4:], m.ciphertext)
1038
1039	m.raw = x
1040	return x
1041}
1042
1043func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
1044	m.raw = data
1045	if len(data) < 4 {
1046		return false
1047	}
1048	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1049	if l != len(data)-4 {
1050		return false
1051	}
1052	m.ciphertext = data[4:]
1053	return true
1054}
1055
// finishedMsg is a Finished handshake message.
//
// verifyData is the message body after the 4-byte handshake header. raw
// caches the wire encoding.
type finishedMsg struct {
	raw        []byte
	verifyData []byte
}
1060
1061func (m *finishedMsg) equal(i interface{}) bool {
1062	m1, ok := i.(*finishedMsg)
1063	if !ok {
1064		return false
1065	}
1066
1067	return bytes.Equal(m.raw, m1.raw) &&
1068		bytes.Equal(m.verifyData, m1.verifyData)
1069}
1070
1071func (m *finishedMsg) marshal() (x []byte) {
1072	if m.raw != nil {
1073		return m.raw
1074	}
1075
1076	x = make([]byte, 4+len(m.verifyData))
1077	x[0] = typeFinished
1078	x[3] = byte(len(m.verifyData))
1079	copy(x[4:], m.verifyData)
1080	m.raw = x
1081	return
1082}
1083
1084func (m *finishedMsg) unmarshal(data []byte) bool {
1085	m.raw = data
1086	if len(data) < 4 {
1087		return false
1088	}
1089	m.verifyData = data[4:]
1090	return true
1091}
1092
// nextProtoMsg is a NextProtocol handshake message.
//
// proto is the selected protocol name. On marshal it is truncated to 255
// bytes and padded. raw caches the wire encoding.
type nextProtoMsg struct {
	raw   []byte
	proto string
}
1097
1098func (m *nextProtoMsg) equal(i interface{}) bool {
1099	m1, ok := i.(*nextProtoMsg)
1100	if !ok {
1101		return false
1102	}
1103
1104	return bytes.Equal(m.raw, m1.raw) &&
1105		m.proto == m1.proto
1106}
1107
1108func (m *nextProtoMsg) marshal() []byte {
1109	if m.raw != nil {
1110		return m.raw
1111	}
1112	l := len(m.proto)
1113	if l > 255 {
1114		l = 255
1115	}
1116
1117	padding := 32 - (l+2)%32
1118	length := l + padding + 2
1119	x := make([]byte, length+4)
1120	x[0] = typeNextProtocol
1121	x[1] = uint8(length >> 16)
1122	x[2] = uint8(length >> 8)
1123	x[3] = uint8(length)
1124
1125	y := x[4:]
1126	y[0] = byte(l)
1127	copy(y[1:], []byte(m.proto[0:l]))
1128	y = y[1+l:]
1129	y[0] = byte(padding)
1130
1131	m.raw = x
1132
1133	return x
1134}
1135
1136func (m *nextProtoMsg) unmarshal(data []byte) bool {
1137	m.raw = data
1138
1139	if len(data) < 5 {
1140		return false
1141	}
1142	data = data[4:]
1143	protoLen := int(data[0])
1144	data = data[1:]
1145	if len(data) < protoLen {
1146		return false
1147	}
1148	m.proto = string(data[0:protoLen])
1149	data = data[protoLen:]
1150
1151	if len(data) < 1 {
1152		return false
1153	}
1154	paddingLen := int(data[0])
1155	data = data[1:]
1156	if len(data) != paddingLen {
1157		return false
1158	}
1159
1160	return true
1161}
1162
// certificateRequestMsg is a CertificateRequest handshake message.
//
// certificateTypes lists the accepted certificate type bytes.
// signatureAndHashes is only encoded and parsed when hasSignatureAndHash is
// set. certificateAuthorities holds the opaque entries of the authorities
// list. raw caches the wire encoding.
//
// hasSignatureAndHash is an input to unmarshal, not an output: the caller
// must set it before parsing.
type certificateRequestMsg struct {
	raw []byte
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2.
	hasSignatureAndHash bool

	certificateTypes       []byte
	signatureAndHashes     []signatureAndHash
	certificateAuthorities [][]byte
}
1174
1175func (m *certificateRequestMsg) equal(i interface{}) bool {
1176	m1, ok := i.(*certificateRequestMsg)
1177	if !ok {
1178		return false
1179	}
1180
1181	return bytes.Equal(m.raw, m1.raw) &&
1182		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
1183		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
1184		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
1185}
1186
1187func (m *certificateRequestMsg) marshal() (x []byte) {
1188	if m.raw != nil {
1189		return m.raw
1190	}
1191
1192	// See http://tools.ietf.org/html/rfc4346#section-7.4.4
1193	length := 1 + len(m.certificateTypes) + 2
1194	casLength := 0
1195	for _, ca := range m.certificateAuthorities {
1196		casLength += 2 + len(ca)
1197	}
1198	length += casLength
1199
1200	if m.hasSignatureAndHash {
1201		length += 2 + 2*len(m.signatureAndHashes)
1202	}
1203
1204	x = make([]byte, 4+length)
1205	x[0] = typeCertificateRequest
1206	x[1] = uint8(length >> 16)
1207	x[2] = uint8(length >> 8)
1208	x[3] = uint8(length)
1209
1210	x[4] = uint8(len(m.certificateTypes))
1211
1212	copy(x[5:], m.certificateTypes)
1213	y := x[5+len(m.certificateTypes):]
1214
1215	if m.hasSignatureAndHash {
1216		n := len(m.signatureAndHashes) * 2
1217		y[0] = uint8(n >> 8)
1218		y[1] = uint8(n)
1219		y = y[2:]
1220		for _, sigAndHash := range m.signatureAndHashes {
1221			y[0] = sigAndHash.hash
1222			y[1] = sigAndHash.signature
1223			y = y[2:]
1224		}
1225	}
1226
1227	y[0] = uint8(casLength >> 8)
1228	y[1] = uint8(casLength)
1229	y = y[2:]
1230	for _, ca := range m.certificateAuthorities {
1231		y[0] = uint8(len(ca) >> 8)
1232		y[1] = uint8(len(ca))
1233		y = y[2:]
1234		copy(y, ca)
1235		y = y[len(ca):]
1236	}
1237
1238	m.raw = x
1239	return
1240}
1241
1242func (m *certificateRequestMsg) unmarshal(data []byte) bool {
1243	m.raw = data
1244
1245	if len(data) < 5 {
1246		return false
1247	}
1248
1249	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1250	if uint32(len(data))-4 != length {
1251		return false
1252	}
1253
1254	numCertTypes := int(data[4])
1255	data = data[5:]
1256	if numCertTypes == 0 || len(data) <= numCertTypes {
1257		return false
1258	}
1259
1260	m.certificateTypes = make([]byte, numCertTypes)
1261	if copy(m.certificateTypes, data) != numCertTypes {
1262		return false
1263	}
1264
1265	data = data[numCertTypes:]
1266
1267	if m.hasSignatureAndHash {
1268		if len(data) < 2 {
1269			return false
1270		}
1271		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
1272		data = data[2:]
1273		if sigAndHashLen&1 != 0 {
1274			return false
1275		}
1276		if len(data) < int(sigAndHashLen) {
1277			return false
1278		}
1279		numSigAndHash := sigAndHashLen / 2
1280		m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
1281		for i := range m.signatureAndHashes {
1282			m.signatureAndHashes[i].hash = data[0]
1283			m.signatureAndHashes[i].signature = data[1]
1284			data = data[2:]
1285		}
1286	}
1287
1288	if len(data) < 2 {
1289		return false
1290	}
1291	casLength := uint16(data[0])<<8 | uint16(data[1])
1292	data = data[2:]
1293	if len(data) < int(casLength) {
1294		return false
1295	}
1296	cas := make([]byte, casLength)
1297	copy(cas, data)
1298	data = data[casLength:]
1299
1300	m.certificateAuthorities = nil
1301	for len(cas) > 0 {
1302		if len(cas) < 2 {
1303			return false
1304		}
1305		caLen := uint16(cas[0])<<8 | uint16(cas[1])
1306		cas = cas[2:]
1307
1308		if len(cas) < int(caLen) {
1309			return false
1310		}
1311
1312		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
1313		cas = cas[caLen:]
1314	}
1315	if len(data) > 0 {
1316		return false
1317	}
1318
1319	return true
1320}
1321
// certificateVerifyMsg is the CertificateVerify handshake message (see
// RFC 4346, section 7.4.8). When hasSignatureAndHash is set, the encoding
// carries the two algorithm bytes of signatureAndHash ahead of the
// length-prefixed signature.
type certificateVerifyMsg struct {
	raw                 []byte           // full wire encoding, including the 4-byte handshake header
	hasSignatureAndHash bool             // selects whether the algorithm pair is encoded/expected; set by the caller before unmarshal
	signatureAndHash    signatureAndHash // algorithm pair; only read or written when hasSignatureAndHash is set
	signature           []byte           // signature bytes, without their two-byte length prefix
}
1328
1329func (m *certificateVerifyMsg) equal(i interface{}) bool {
1330	m1, ok := i.(*certificateVerifyMsg)
1331	if !ok {
1332		return false
1333	}
1334
1335	return bytes.Equal(m.raw, m1.raw) &&
1336		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
1337		m.signatureAndHash.hash == m1.signatureAndHash.hash &&
1338		m.signatureAndHash.signature == m1.signatureAndHash.signature &&
1339		bytes.Equal(m.signature, m1.signature)
1340}
1341
1342func (m *certificateVerifyMsg) marshal() (x []byte) {
1343	if m.raw != nil {
1344		return m.raw
1345	}
1346
1347	// See http://tools.ietf.org/html/rfc4346#section-7.4.8
1348	siglength := len(m.signature)
1349	length := 2 + siglength
1350	if m.hasSignatureAndHash {
1351		length += 2
1352	}
1353	x = make([]byte, 4+length)
1354	x[0] = typeCertificateVerify
1355	x[1] = uint8(length >> 16)
1356	x[2] = uint8(length >> 8)
1357	x[3] = uint8(length)
1358	y := x[4:]
1359	if m.hasSignatureAndHash {
1360		y[0] = m.signatureAndHash.hash
1361		y[1] = m.signatureAndHash.signature
1362		y = y[2:]
1363	}
1364	y[0] = uint8(siglength >> 8)
1365	y[1] = uint8(siglength)
1366	copy(y[2:], m.signature)
1367
1368	m.raw = x
1369
1370	return
1371}
1372
1373func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
1374	m.raw = data
1375
1376	if len(data) < 6 {
1377		return false
1378	}
1379
1380	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1381	if uint32(len(data))-4 != length {
1382		return false
1383	}
1384
1385	data = data[4:]
1386	if m.hasSignatureAndHash {
1387		m.signatureAndHash.hash = data[0]
1388		m.signatureAndHash.signature = data[1]
1389		data = data[2:]
1390	}
1391
1392	if len(data) < 2 {
1393		return false
1394	}
1395	siglength := int(data[0])<<8 + int(data[1])
1396	data = data[2:]
1397	if len(data) != siglength {
1398		return false
1399	}
1400
1401	m.signature = data
1402
1403	return true
1404}
1405
// newSessionTicketMsg is the NewSessionTicket handshake message (see
// RFC 5077, section 3.3), carrying a session ticket.
type newSessionTicketMsg struct {
	raw    []byte // full wire encoding, including the 4-byte handshake header
	ticket []byte // ticket bytes, without their two-byte length prefix
}
1410
1411func (m *newSessionTicketMsg) equal(i interface{}) bool {
1412	m1, ok := i.(*newSessionTicketMsg)
1413	if !ok {
1414		return false
1415	}
1416
1417	return bytes.Equal(m.raw, m1.raw) &&
1418		bytes.Equal(m.ticket, m1.ticket)
1419}
1420
1421func (m *newSessionTicketMsg) marshal() (x []byte) {
1422	if m.raw != nil {
1423		return m.raw
1424	}
1425
1426	// See http://tools.ietf.org/html/rfc5077#section-3.3
1427	ticketLen := len(m.ticket)
1428	length := 2 + 4 + ticketLen
1429	x = make([]byte, 4+length)
1430	x[0] = typeNewSessionTicket
1431	x[1] = uint8(length >> 16)
1432	x[2] = uint8(length >> 8)
1433	x[3] = uint8(length)
1434	x[8] = uint8(ticketLen >> 8)
1435	x[9] = uint8(ticketLen)
1436	copy(x[10:], m.ticket)
1437
1438	m.raw = x
1439
1440	return
1441}
1442
1443func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
1444	m.raw = data
1445
1446	if len(data) < 10 {
1447		return false
1448	}
1449
1450	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1451	if uint32(len(data))-4 != length {
1452		return false
1453	}
1454
1455	ticketLen := int(data[8])<<8 + int(data[9])
1456	if len(data)-10 != ticketLen {
1457		return false
1458	}
1459
1460	m.ticket = data[10:]
1461
1462	return true
1463}
1464
// eqUint16s reports whether x and y have the same length and the same
// elements in the same order. A nil slice equals an empty one.
func eqUint16s(x, y []uint16) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	i := 0
	for i < n && x[i] == y[i] {
		i++
	}
	return i == n
}
1476
1477func eqCurveIDs(x, y []CurveID) bool {
1478	if len(x) != len(y) {
1479		return false
1480	}
1481	for i, v := range x {
1482		if y[i] != v {
1483			return false
1484		}
1485	}
1486	return true
1487}
1488
// eqStrings reports whether x and y have the same length and the same
// strings in the same order. A nil slice equals an empty one.
func eqStrings(x, y []string) bool {
	n := len(x)
	if n != len(y) {
		return false
	}
	i := 0
	for i < n && x[i] == y[i] {
		i++
	}
	return i == n
}
1500
// eqByteSlices reports whether x and y have the same length and each pair
// of corresponding elements holds the same bytes (per bytes.Equal, so a nil
// element matches an empty one).
func eqByteSlices(x, y [][]byte) bool {
	if len(x) != len(y) {
		return false
	}
	for i := len(x) - 1; i >= 0; i-- {
		if !bytes.Equal(x[i], y[i]) {
			return false
		}
	}
	return true
}
1512
1513func eqSignatureAndHashes(x, y []signatureAndHash) bool {
1514	if len(x) != len(y) {
1515		return false
1516	}
1517	for i, v := range x {
1518		v2 := y[i]
1519		if v.hash != v2.hash || v.signature != v2.signature {
1520			return false
1521		}
1522	}
1523	return true
1524}
1525