1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import "bytes"
8
// clientHelloMsg is the parsed form of a ClientHello handshake message.
//
// raw caches the complete wire encoding, including the 4-byte handshake
// header: marshal returns it unchanged when non-nil and otherwise stores its
// own output there, and unmarshal records the slice it parsed. The other
// fields mirror the message body. In marshal, each boolean field and each
// non-empty slice/string field causes the corresponding extension to be
// emitted. In unmarshal, they report which extensions were seen.
// ocspStapling is only reported for a status_request of type
// statusTypeOCSP. secureRenegotiation is also reported when the cipher suite
// list contains scsvRenegotiation.
type clientHelloMsg struct {
	raw                 []byte
	vers                uint16
	random              []byte
	sessionId           []byte
	cipherSuites        []uint16
	compressionMethods  []uint8
	nextProtoNeg        bool
	serverName          string
	ocspStapling        bool
	scts                bool
	supportedCurves     []CurveID
	supportedPoints     []uint8
	ticketSupported     bool
	sessionTicket       []uint8
	signatureAndHashes  []signatureAndHash
	secureRenegotiation bool
	alpnProtocols       []string
}
28
29func (m *clientHelloMsg) equal(i interface{}) bool {
30	m1, ok := i.(*clientHelloMsg)
31	if !ok {
32		return false
33	}
34
35	return bytes.Equal(m.raw, m1.raw) &&
36		m.vers == m1.vers &&
37		bytes.Equal(m.random, m1.random) &&
38		bytes.Equal(m.sessionId, m1.sessionId) &&
39		eqUint16s(m.cipherSuites, m1.cipherSuites) &&
40		bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
41		m.nextProtoNeg == m1.nextProtoNeg &&
42		m.serverName == m1.serverName &&
43		m.ocspStapling == m1.ocspStapling &&
44		m.scts == m1.scts &&
45		eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
46		bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
47		m.ticketSupported == m1.ticketSupported &&
48		bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
49		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) &&
50		m.secureRenegotiation == m1.secureRenegotiation &&
51		eqStrings(m.alpnProtocols, m1.alpnProtocols)
52}
53
// marshal encodes m as a ClientHello handshake message, including the 4-byte
// handshake header (message type plus 24-bit body length).
//
// If m.raw is already set, it is returned unchanged. Otherwise the encoding
// is built, stored in m.raw, and returned.
//
// Encoding takes two passes. The first computes the body length and the
// extension count so the output buffer is allocated once. The second writes
// each field at its offset. The sizes added in the first pass must match
// exactly what the second pass writes. The buffer comes from make and is
// therefore zero-filled; several fields (zero lengths, the SNI name type)
// rely on that instead of being written explicitly.
//
// It panics if any entry in m.alpnProtocols is empty or longer than 255
// bytes.
func (m *clientHelloMsg) marshal() []byte {
	if m.raw != nil {
		return m.raw
	}

	// Fixed part of the body: version (2), random (32), session ID length (1)
	// and bytes, cipher suites length (2) and suites (2 each), compression
	// methods length (1) and bytes.
	length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
	numExtensions := 0
	// extensionsLength accumulates only extension_data sizes in this pass;
	// the 4-byte type+length header of every extension is added once
	// numExtensions is known.
	extensionsLength := 0
	if m.nextProtoNeg {
		// Empty extension_data.
		numExtensions++
	}
	if m.ocspStapling {
		// Status type (1) plus two zero-valued uint16 lengths.
		extensionsLength += 1 + 2 + 2
		numExtensions++
	}
	if len(m.serverName) > 0 {
		// List length (2) + name type (1) + name length (2) + name.
		extensionsLength += 5 + len(m.serverName)
		numExtensions++
	}
	if len(m.supportedCurves) > 0 {
		// List length (2) + 2 bytes per curve.
		extensionsLength += 2 + 2*len(m.supportedCurves)
		numExtensions++
	}
	if len(m.supportedPoints) > 0 {
		// List length (1) + 1 byte per point format.
		extensionsLength += 1 + len(m.supportedPoints)
		numExtensions++
	}
	if m.ticketSupported {
		// The ticket bytes themselves; may be zero.
		extensionsLength += len(m.sessionTicket)
		numExtensions++
	}
	if len(m.signatureAndHashes) > 0 {
		// List length (2) + 2 bytes (hash, signature) per entry.
		extensionsLength += 2 + 2*len(m.signatureAndHashes)
		numExtensions++
	}
	if m.secureRenegotiation {
		// A single zero byte.
		extensionsLength += 1
		numExtensions++
	}
	if len(m.alpnProtocols) > 0 {
		// List length (2), then a 1-byte length and the name for each protocol.
		extensionsLength += 2
		for _, s := range m.alpnProtocols {
			if l := len(s); l == 0 || l > 255 {
				panic("invalid ALPN protocol")
			}
			extensionsLength++
			extensionsLength += len(s)
		}
		numExtensions++
	}
	if m.scts {
		// Empty extension_data.
		numExtensions++
	}
	if numExtensions > 0 {
		// A 4-byte header per extension, plus the 2-byte total length that
		// precedes the extension block. With no extensions the block is
		// omitted entirely.
		extensionsLength += 4 * numExtensions
		length += 2 + extensionsLength
	}

	x := make([]byte, 4+length)
	x[0] = typeClientHello
	x[1] = uint8(length >> 16)
	x[2] = uint8(length >> 8)
	x[3] = uint8(length)
	x[4] = uint8(m.vers >> 8)
	x[5] = uint8(m.vers)
	copy(x[6:38], m.random)
	x[38] = uint8(len(m.sessionId))
	copy(x[39:39+len(m.sessionId)], m.sessionId)
	y := x[39+len(m.sessionId):]
	// The cipher suites length is len*2 bytes. Its high byte is (len*2)>>8 ==
	// len>>7 and its low byte is the truncation of len<<1.
	y[0] = uint8(len(m.cipherSuites) >> 7)
	y[1] = uint8(len(m.cipherSuites) << 1)
	for i, suite := range m.cipherSuites {
		y[2+i*2] = uint8(suite >> 8)
		y[3+i*2] = uint8(suite)
	}
	z := y[2+len(m.cipherSuites)*2:]
	z[0] = uint8(len(m.compressionMethods))
	copy(z[1:], m.compressionMethods)

	// From here on, z always points at the next unwritten byte.
	z = z[1+len(m.compressionMethods):]
	if numExtensions > 0 {
		z[0] = byte(extensionsLength >> 8)
		z[1] = byte(extensionsLength)
		z = z[2:]
	}
	if m.nextProtoNeg {
		z[0] = byte(extensionNextProtoNeg >> 8)
		z[1] = byte(extensionNextProtoNeg & 0xff)
		// The length is always 0
		z = z[4:]
	}
	if len(m.serverName) > 0 {
		z[0] = byte(extensionServerName >> 8)
		z[1] = byte(extensionServerName & 0xff)
		// l is the extension_data length, as sized in the first pass.
		l := len(m.serverName) + 5
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// RFC 3546, section 3.1
		//
		// struct {
		//     NameType name_type;
		//     select (name_type) {
		//         case host_name: HostName;
		//     } name;
		// } ServerName;
		//
		// enum {
		//     host_name(0), (255)
		// } NameType;
		//
		// opaque HostName<1..2^16-1>;
		//
		// struct {
		//     ServerName server_name_list<1..2^16-1>
		// } ServerNameList;

		// server_name_list length, then the single entry. z[2] is the
		// name_type and is left zero (host_name).
		z[0] = byte((len(m.serverName) + 3) >> 8)
		z[1] = byte(len(m.serverName) + 3)
		z[3] = byte(len(m.serverName) >> 8)
		z[4] = byte(len(m.serverName))
		copy(z[5:], []byte(m.serverName))
		z = z[l:]
	}
	if m.ocspStapling {
		// RFC 4366, section 3.6
		z[0] = byte(extensionStatusRequest >> 8)
		z[1] = byte(extensionStatusRequest)
		z[2] = 0
		z[3] = 5
		z[4] = 1 // OCSP type
		// Two zero valued uint16s for the two lengths.
		z = z[9:]
	}
	if len(m.supportedCurves) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.1
		z[0] = byte(extensionSupportedCurves >> 8)
		z[1] = byte(extensionSupportedCurves)
		// Extension length, then the inner list length (2 bytes less).
		l := 2 + 2*len(m.supportedCurves)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l -= 2
		z[4] = byte(l >> 8)
		z[5] = byte(l)
		z = z[6:]
		for _, curve := range m.supportedCurves {
			z[0] = byte(curve >> 8)
			z[1] = byte(curve)
			z = z[2:]
		}
	}
	if len(m.supportedPoints) > 0 {
		// http://tools.ietf.org/html/rfc4492#section-5.5.2
		z[0] = byte(extensionSupportedPoints >> 8)
		z[1] = byte(extensionSupportedPoints)
		// Extension length, then the 1-byte inner list length.
		l := 1 + len(m.supportedPoints)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		l--
		z[4] = byte(l)
		z = z[5:]
		for _, pointFormat := range m.supportedPoints {
			z[0] = byte(pointFormat)
			z = z[1:]
		}
	}
	if m.ticketSupported {
		// http://tools.ietf.org/html/rfc5077#section-3.2
		z[0] = byte(extensionSessionTicket >> 8)
		z[1] = byte(extensionSessionTicket)
		l := len(m.sessionTicket)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]
		copy(z, m.sessionTicket)
		z = z[len(m.sessionTicket):]
	}
	if len(m.signatureAndHashes) > 0 {
		// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
		z[0] = byte(extensionSignatureAlgorithms >> 8)
		z[1] = byte(extensionSignatureAlgorithms)
		l := 2 + 2*len(m.signatureAndHashes)
		z[2] = byte(l >> 8)
		z[3] = byte(l)
		z = z[4:]

		// Inner list length, then (hash, signature) byte pairs.
		l -= 2
		z[0] = byte(l >> 8)
		z[1] = byte(l)
		z = z[2:]
		for _, sigAndHash := range m.signatureAndHashes {
			z[0] = sigAndHash.hash
			z[1] = sigAndHash.signature
			z = z[2:]
		}
	}
	if m.secureRenegotiation {
		z[0] = byte(extensionRenegotiationInfo >> 8)
		z[1] = byte(extensionRenegotiationInfo & 0xff)
		z[2] = 0
		z[3] = 1
		// z[4] is the single data byte and is left zero.
		z = z[5:]
	}
	if len(m.alpnProtocols) > 0 {
		z[0] = byte(extensionALPN >> 8)
		z[1] = byte(extensionALPN & 0xff)
		// The extension length (lengths[0:2]) and list length (lengths[2:4])
		// are filled in after the names have been written.
		lengths := z[2:]
		z = z[6:]

		stringsLength := 0
		for _, s := range m.alpnProtocols {
			l := len(s)
			z[0] = byte(l)
			copy(z[1:], s)
			z = z[1+l:]
			stringsLength += 1 + l
		}

		lengths[2] = byte(stringsLength >> 8)
		lengths[3] = byte(stringsLength)
		stringsLength += 2
		lengths[0] = byte(stringsLength >> 8)
		lengths[1] = byte(stringsLength)
	}
	if m.scts {
		// https://tools.ietf.org/html/rfc6962#section-3.3.1
		z[0] = byte(extensionSCT >> 8)
		z[1] = byte(extensionSCT)
		// zero uint16 for the zero-length extension_data
		z = z[4:]
	}

	m.raw = x

	return x
}
291
292func (m *clientHelloMsg) unmarshal(data []byte) bool {
293	if len(data) < 42 {
294		return false
295	}
296	m.raw = data
297	m.vers = uint16(data[4])<<8 | uint16(data[5])
298	m.random = data[6:38]
299	sessionIdLen := int(data[38])
300	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
301		return false
302	}
303	m.sessionId = data[39 : 39+sessionIdLen]
304	data = data[39+sessionIdLen:]
305	if len(data) < 2 {
306		return false
307	}
308	// cipherSuiteLen is the number of bytes of cipher suite numbers. Since
309	// they are uint16s, the number must be even.
310	cipherSuiteLen := int(data[0])<<8 | int(data[1])
311	if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
312		return false
313	}
314	numCipherSuites := cipherSuiteLen / 2
315	m.cipherSuites = make([]uint16, numCipherSuites)
316	for i := 0; i < numCipherSuites; i++ {
317		m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
318		if m.cipherSuites[i] == scsvRenegotiation {
319			m.secureRenegotiation = true
320		}
321	}
322	data = data[2+cipherSuiteLen:]
323	if len(data) < 1 {
324		return false
325	}
326	compressionMethodsLen := int(data[0])
327	if len(data) < 1+compressionMethodsLen {
328		return false
329	}
330	m.compressionMethods = data[1 : 1+compressionMethodsLen]
331
332	data = data[1+compressionMethodsLen:]
333
334	m.nextProtoNeg = false
335	m.serverName = ""
336	m.ocspStapling = false
337	m.ticketSupported = false
338	m.sessionTicket = nil
339	m.signatureAndHashes = nil
340	m.alpnProtocols = nil
341	m.scts = false
342
343	if len(data) == 0 {
344		// ClientHello is optionally followed by extension data
345		return true
346	}
347	if len(data) < 2 {
348		return false
349	}
350
351	extensionsLength := int(data[0])<<8 | int(data[1])
352	data = data[2:]
353	if extensionsLength != len(data) {
354		return false
355	}
356
357	for len(data) != 0 {
358		if len(data) < 4 {
359			return false
360		}
361		extension := uint16(data[0])<<8 | uint16(data[1])
362		length := int(data[2])<<8 | int(data[3])
363		data = data[4:]
364		if len(data) < length {
365			return false
366		}
367
368		switch extension {
369		case extensionServerName:
370			d := data[:length]
371			if len(d) < 2 {
372				return false
373			}
374			namesLen := int(d[0])<<8 | int(d[1])
375			d = d[2:]
376			if len(d) != namesLen {
377				return false
378			}
379			for len(d) > 0 {
380				if len(d) < 3 {
381					return false
382				}
383				nameType := d[0]
384				nameLen := int(d[1])<<8 | int(d[2])
385				d = d[3:]
386				if len(d) < nameLen {
387					return false
388				}
389				if nameType == 0 {
390					m.serverName = string(d[:nameLen])
391					break
392				}
393				d = d[nameLen:]
394			}
395		case extensionNextProtoNeg:
396			if length > 0 {
397				return false
398			}
399			m.nextProtoNeg = true
400		case extensionStatusRequest:
401			m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
402		case extensionSupportedCurves:
403			// http://tools.ietf.org/html/rfc4492#section-5.5.1
404			if length < 2 {
405				return false
406			}
407			l := int(data[0])<<8 | int(data[1])
408			if l%2 == 1 || length != l+2 {
409				return false
410			}
411			numCurves := l / 2
412			m.supportedCurves = make([]CurveID, numCurves)
413			d := data[2:]
414			for i := 0; i < numCurves; i++ {
415				m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
416				d = d[2:]
417			}
418		case extensionSupportedPoints:
419			// http://tools.ietf.org/html/rfc4492#section-5.5.2
420			if length < 1 {
421				return false
422			}
423			l := int(data[0])
424			if length != l+1 {
425				return false
426			}
427			m.supportedPoints = make([]uint8, l)
428			copy(m.supportedPoints, data[1:])
429		case extensionSessionTicket:
430			// http://tools.ietf.org/html/rfc5077#section-3.2
431			m.ticketSupported = true
432			m.sessionTicket = data[:length]
433		case extensionSignatureAlgorithms:
434			// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
435			if length < 2 || length&1 != 0 {
436				return false
437			}
438			l := int(data[0])<<8 | int(data[1])
439			if l != length-2 {
440				return false
441			}
442			n := l / 2
443			d := data[2:]
444			m.signatureAndHashes = make([]signatureAndHash, n)
445			for i := range m.signatureAndHashes {
446				m.signatureAndHashes[i].hash = d[0]
447				m.signatureAndHashes[i].signature = d[1]
448				d = d[2:]
449			}
450		case extensionRenegotiationInfo:
451			if length != 1 || data[0] != 0 {
452				return false
453			}
454			m.secureRenegotiation = true
455		case extensionALPN:
456			if length < 2 {
457				return false
458			}
459			l := int(data[0])<<8 | int(data[1])
460			if l != length-2 {
461				return false
462			}
463			d := data[2:length]
464			for len(d) != 0 {
465				stringLen := int(d[0])
466				d = d[1:]
467				if stringLen == 0 || stringLen > len(d) {
468					return false
469				}
470				m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
471				d = d[stringLen:]
472			}
473		case extensionSCT:
474			m.scts = true
475			if length != 0 {
476				return false
477			}
478		}
479		data = data[length:]
480	}
481
482	return true
483}
484
// serverHelloMsg is the parsed form of a ServerHello handshake message.
//
// raw caches the complete wire encoding, including the 4-byte handshake
// header: marshal returns it unchanged when non-nil and otherwise stores its
// output there, and unmarshal records the slice it parsed. nextProtos is
// only encoded when nextProtoNeg is set. scts holds opaque SCT blobs, each
// written with a 2-byte length. An empty alpnProtocol means the ALPN
// extension is omitted.
type serverHelloMsg struct {
	raw                 []byte
	vers                uint16
	random              []byte
	sessionId           []byte
	cipherSuite         uint16
	compressionMethod   uint8
	nextProtoNeg        bool
	nextProtos          []string
	ocspStapling        bool
	scts                [][]byte
	ticketSupported     bool
	secureRenegotiation bool
	alpnProtocol        string
}
500
501func (m *serverHelloMsg) equal(i interface{}) bool {
502	m1, ok := i.(*serverHelloMsg)
503	if !ok {
504		return false
505	}
506
507	if len(m.scts) != len(m1.scts) {
508		return false
509	}
510	for i, sct := range m.scts {
511		if !bytes.Equal(sct, m1.scts[i]) {
512			return false
513		}
514	}
515
516	return bytes.Equal(m.raw, m1.raw) &&
517		m.vers == m1.vers &&
518		bytes.Equal(m.random, m1.random) &&
519		bytes.Equal(m.sessionId, m1.sessionId) &&
520		m.cipherSuite == m1.cipherSuite &&
521		m.compressionMethod == m1.compressionMethod &&
522		m.nextProtoNeg == m1.nextProtoNeg &&
523		eqStrings(m.nextProtos, m1.nextProtos) &&
524		m.ocspStapling == m1.ocspStapling &&
525		m.ticketSupported == m1.ticketSupported &&
526		m.secureRenegotiation == m1.secureRenegotiation &&
527		m.alpnProtocol == m1.alpnProtocol
528}
529
530func (m *serverHelloMsg) marshal() []byte {
531	if m.raw != nil {
532		return m.raw
533	}
534
535	length := 38 + len(m.sessionId)
536	numExtensions := 0
537	extensionsLength := 0
538
539	nextProtoLen := 0
540	if m.nextProtoNeg {
541		numExtensions++
542		for _, v := range m.nextProtos {
543			nextProtoLen += len(v)
544		}
545		nextProtoLen += len(m.nextProtos)
546		extensionsLength += nextProtoLen
547	}
548	if m.ocspStapling {
549		numExtensions++
550	}
551	if m.ticketSupported {
552		numExtensions++
553	}
554	if m.secureRenegotiation {
555		extensionsLength += 1
556		numExtensions++
557	}
558	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
559		if alpnLen >= 256 {
560			panic("invalid ALPN protocol")
561		}
562		extensionsLength += 2 + 1 + alpnLen
563		numExtensions++
564	}
565	sctLen := 0
566	if len(m.scts) > 0 {
567		for _, sct := range m.scts {
568			sctLen += len(sct) + 2
569		}
570		extensionsLength += 2 + sctLen
571		numExtensions++
572	}
573
574	if numExtensions > 0 {
575		extensionsLength += 4 * numExtensions
576		length += 2 + extensionsLength
577	}
578
579	x := make([]byte, 4+length)
580	x[0] = typeServerHello
581	x[1] = uint8(length >> 16)
582	x[2] = uint8(length >> 8)
583	x[3] = uint8(length)
584	x[4] = uint8(m.vers >> 8)
585	x[5] = uint8(m.vers)
586	copy(x[6:38], m.random)
587	x[38] = uint8(len(m.sessionId))
588	copy(x[39:39+len(m.sessionId)], m.sessionId)
589	z := x[39+len(m.sessionId):]
590	z[0] = uint8(m.cipherSuite >> 8)
591	z[1] = uint8(m.cipherSuite)
592	z[2] = uint8(m.compressionMethod)
593
594	z = z[3:]
595	if numExtensions > 0 {
596		z[0] = byte(extensionsLength >> 8)
597		z[1] = byte(extensionsLength)
598		z = z[2:]
599	}
600	if m.nextProtoNeg {
601		z[0] = byte(extensionNextProtoNeg >> 8)
602		z[1] = byte(extensionNextProtoNeg & 0xff)
603		z[2] = byte(nextProtoLen >> 8)
604		z[3] = byte(nextProtoLen)
605		z = z[4:]
606
607		for _, v := range m.nextProtos {
608			l := len(v)
609			if l > 255 {
610				l = 255
611			}
612			z[0] = byte(l)
613			copy(z[1:], []byte(v[0:l]))
614			z = z[1+l:]
615		}
616	}
617	if m.ocspStapling {
618		z[0] = byte(extensionStatusRequest >> 8)
619		z[1] = byte(extensionStatusRequest)
620		z = z[4:]
621	}
622	if m.ticketSupported {
623		z[0] = byte(extensionSessionTicket >> 8)
624		z[1] = byte(extensionSessionTicket)
625		z = z[4:]
626	}
627	if m.secureRenegotiation {
628		z[0] = byte(extensionRenegotiationInfo >> 8)
629		z[1] = byte(extensionRenegotiationInfo & 0xff)
630		z[2] = 0
631		z[3] = 1
632		z = z[5:]
633	}
634	if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
635		z[0] = byte(extensionALPN >> 8)
636		z[1] = byte(extensionALPN & 0xff)
637		l := 2 + 1 + alpnLen
638		z[2] = byte(l >> 8)
639		z[3] = byte(l)
640		l -= 2
641		z[4] = byte(l >> 8)
642		z[5] = byte(l)
643		l -= 1
644		z[6] = byte(l)
645		copy(z[7:], []byte(m.alpnProtocol))
646		z = z[7+alpnLen:]
647	}
648	if sctLen > 0 {
649		z[0] = byte(extensionSCT >> 8)
650		z[1] = byte(extensionSCT)
651		l := sctLen + 2
652		z[2] = byte(l >> 8)
653		z[3] = byte(l)
654		z[4] = byte(sctLen >> 8)
655		z[5] = byte(sctLen)
656
657		z = z[6:]
658		for _, sct := range m.scts {
659			z[0] = byte(len(sct) >> 8)
660			z[1] = byte(len(sct))
661			copy(z[2:], sct)
662			z = z[len(sct)+2:]
663		}
664	}
665
666	m.raw = x
667
668	return x
669}
670
671func (m *serverHelloMsg) unmarshal(data []byte) bool {
672	if len(data) < 42 {
673		return false
674	}
675	m.raw = data
676	m.vers = uint16(data[4])<<8 | uint16(data[5])
677	m.random = data[6:38]
678	sessionIdLen := int(data[38])
679	if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
680		return false
681	}
682	m.sessionId = data[39 : 39+sessionIdLen]
683	data = data[39+sessionIdLen:]
684	if len(data) < 3 {
685		return false
686	}
687	m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
688	m.compressionMethod = data[2]
689	data = data[3:]
690
691	m.nextProtoNeg = false
692	m.nextProtos = nil
693	m.ocspStapling = false
694	m.scts = nil
695	m.ticketSupported = false
696	m.alpnProtocol = ""
697
698	if len(data) == 0 {
699		// ServerHello is optionally followed by extension data
700		return true
701	}
702	if len(data) < 2 {
703		return false
704	}
705
706	extensionsLength := int(data[0])<<8 | int(data[1])
707	data = data[2:]
708	if len(data) != extensionsLength {
709		return false
710	}
711
712	for len(data) != 0 {
713		if len(data) < 4 {
714			return false
715		}
716		extension := uint16(data[0])<<8 | uint16(data[1])
717		length := int(data[2])<<8 | int(data[3])
718		data = data[4:]
719		if len(data) < length {
720			return false
721		}
722
723		switch extension {
724		case extensionNextProtoNeg:
725			m.nextProtoNeg = true
726			d := data[:length]
727			for len(d) > 0 {
728				l := int(d[0])
729				d = d[1:]
730				if l == 0 || l > len(d) {
731					return false
732				}
733				m.nextProtos = append(m.nextProtos, string(d[:l]))
734				d = d[l:]
735			}
736		case extensionStatusRequest:
737			if length > 0 {
738				return false
739			}
740			m.ocspStapling = true
741		case extensionSessionTicket:
742			if length > 0 {
743				return false
744			}
745			m.ticketSupported = true
746		case extensionRenegotiationInfo:
747			if length != 1 || data[0] != 0 {
748				return false
749			}
750			m.secureRenegotiation = true
751		case extensionALPN:
752			d := data[:length]
753			if len(d) < 3 {
754				return false
755			}
756			l := int(d[0])<<8 | int(d[1])
757			if l != len(d)-2 {
758				return false
759			}
760			d = d[2:]
761			l = int(d[0])
762			if l != len(d)-1 {
763				return false
764			}
765			d = d[1:]
766			if len(d) == 0 {
767				// ALPN protocols must not be empty.
768				return false
769			}
770			m.alpnProtocol = string(d)
771		case extensionSCT:
772			d := data[:length]
773
774			if len(d) < 2 {
775				return false
776			}
777			l := int(d[0])<<8 | int(d[1])
778			d = d[2:]
779			if len(d) != l {
780				return false
781			}
782			if l == 0 {
783				continue
784			}
785
786			m.scts = make([][]byte, 0, 3)
787			for len(d) != 0 {
788				if len(d) < 2 {
789					return false
790				}
791				sctLen := int(d[0])<<8 | int(d[1])
792				d = d[2:]
793				if len(d) < sctLen {
794					return false
795				}
796				m.scts = append(m.scts, d[:sctLen])
797				d = d[sctLen:]
798			}
799		}
800		data = data[length:]
801	}
802
803	return true
804}
805
// certificateMsg is the parsed form of a Certificate handshake message.
//
// raw caches the complete wire encoding including the 4-byte handshake
// header. certificates holds each certificate's bytes; on the wire, each is
// preceded by a 24-bit length, and the whole list by another 24-bit length.
type certificateMsg struct {
	raw          []byte
	certificates [][]byte
}
810
811func (m *certificateMsg) equal(i interface{}) bool {
812	m1, ok := i.(*certificateMsg)
813	if !ok {
814		return false
815	}
816
817	return bytes.Equal(m.raw, m1.raw) &&
818		eqByteSlices(m.certificates, m1.certificates)
819}
820
821func (m *certificateMsg) marshal() (x []byte) {
822	if m.raw != nil {
823		return m.raw
824	}
825
826	var i int
827	for _, slice := range m.certificates {
828		i += len(slice)
829	}
830
831	length := 3 + 3*len(m.certificates) + i
832	x = make([]byte, 4+length)
833	x[0] = typeCertificate
834	x[1] = uint8(length >> 16)
835	x[2] = uint8(length >> 8)
836	x[3] = uint8(length)
837
838	certificateOctets := length - 3
839	x[4] = uint8(certificateOctets >> 16)
840	x[5] = uint8(certificateOctets >> 8)
841	x[6] = uint8(certificateOctets)
842
843	y := x[7:]
844	for _, slice := range m.certificates {
845		y[0] = uint8(len(slice) >> 16)
846		y[1] = uint8(len(slice) >> 8)
847		y[2] = uint8(len(slice))
848		copy(y[3:], slice)
849		y = y[3+len(slice):]
850	}
851
852	m.raw = x
853	return
854}
855
856func (m *certificateMsg) unmarshal(data []byte) bool {
857	if len(data) < 7 {
858		return false
859	}
860
861	m.raw = data
862	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
863	if uint32(len(data)) != certsLen+7 {
864		return false
865	}
866
867	numCerts := 0
868	d := data[7:]
869	for certsLen > 0 {
870		if len(d) < 4 {
871			return false
872		}
873		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
874		if uint32(len(d)) < 3+certLen {
875			return false
876		}
877		d = d[3+certLen:]
878		certsLen -= 3 + certLen
879		numCerts++
880	}
881
882	m.certificates = make([][]byte, numCerts)
883	d = data[7:]
884	for i := 0; i < numCerts; i++ {
885		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
886		m.certificates[i] = d[3 : 3+certLen]
887		d = d[3+certLen:]
888	}
889
890	return true
891}
892
// serverKeyExchangeMsg is a ServerKeyExchange handshake message. raw caches
// the complete encoding including the 4-byte header. key is the entire
// message body, which is treated as opaque bytes here.
type serverKeyExchangeMsg struct {
	raw []byte
	key []byte
}
897
898func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
899	m1, ok := i.(*serverKeyExchangeMsg)
900	if !ok {
901		return false
902	}
903
904	return bytes.Equal(m.raw, m1.raw) &&
905		bytes.Equal(m.key, m1.key)
906}
907
908func (m *serverKeyExchangeMsg) marshal() []byte {
909	if m.raw != nil {
910		return m.raw
911	}
912	length := len(m.key)
913	x := make([]byte, length+4)
914	x[0] = typeServerKeyExchange
915	x[1] = uint8(length >> 16)
916	x[2] = uint8(length >> 8)
917	x[3] = uint8(length)
918	copy(x[4:], m.key)
919
920	m.raw = x
921	return x
922}
923
924func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
925	m.raw = data
926	if len(data) < 4 {
927		return false
928	}
929	m.key = data[4:]
930	return true
931}
932
// certificateStatusMsg is a CertificateStatus handshake message. raw caches
// the complete encoding including the 4-byte header. response is only
// encoded and parsed when statusType is statusTypeOCSP.
type certificateStatusMsg struct {
	raw        []byte
	statusType uint8
	response   []byte
}
938
939func (m *certificateStatusMsg) equal(i interface{}) bool {
940	m1, ok := i.(*certificateStatusMsg)
941	if !ok {
942		return false
943	}
944
945	return bytes.Equal(m.raw, m1.raw) &&
946		m.statusType == m1.statusType &&
947		bytes.Equal(m.response, m1.response)
948}
949
950func (m *certificateStatusMsg) marshal() []byte {
951	if m.raw != nil {
952		return m.raw
953	}
954
955	var x []byte
956	if m.statusType == statusTypeOCSP {
957		x = make([]byte, 4+4+len(m.response))
958		x[0] = typeCertificateStatus
959		l := len(m.response) + 4
960		x[1] = byte(l >> 16)
961		x[2] = byte(l >> 8)
962		x[3] = byte(l)
963		x[4] = statusTypeOCSP
964
965		l -= 4
966		x[5] = byte(l >> 16)
967		x[6] = byte(l >> 8)
968		x[7] = byte(l)
969		copy(x[8:], m.response)
970	} else {
971		x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
972	}
973
974	m.raw = x
975	return x
976}
977
978func (m *certificateStatusMsg) unmarshal(data []byte) bool {
979	m.raw = data
980	if len(data) < 5 {
981		return false
982	}
983	m.statusType = data[4]
984
985	m.response = nil
986	if m.statusType == statusTypeOCSP {
987		if len(data) < 8 {
988			return false
989		}
990		respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
991		if uint32(len(data)) != 4+4+respLen {
992			return false
993		}
994		m.response = data[8:]
995	}
996	return true
997}
998
// serverHelloDoneMsg is a ServerHelloDone handshake message. Its body is
// empty, so the encoding is just the 4-byte header and there is nothing to
// cache.
type serverHelloDoneMsg struct{}
1000
1001func (m *serverHelloDoneMsg) equal(i interface{}) bool {
1002	_, ok := i.(*serverHelloDoneMsg)
1003	return ok
1004}
1005
1006func (m *serverHelloDoneMsg) marshal() []byte {
1007	x := make([]byte, 4)
1008	x[0] = typeServerHelloDone
1009	return x
1010}
1011
1012func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
1013	return len(data) == 4
1014}
1015
// clientKeyExchangeMsg is a ClientKeyExchange handshake message. raw caches
// the complete encoding including the 4-byte header. ciphertext is the
// entire message body, treated as opaque bytes here.
type clientKeyExchangeMsg struct {
	raw        []byte
	ciphertext []byte
}
1020
1021func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
1022	m1, ok := i.(*clientKeyExchangeMsg)
1023	if !ok {
1024		return false
1025	}
1026
1027	return bytes.Equal(m.raw, m1.raw) &&
1028		bytes.Equal(m.ciphertext, m1.ciphertext)
1029}
1030
1031func (m *clientKeyExchangeMsg) marshal() []byte {
1032	if m.raw != nil {
1033		return m.raw
1034	}
1035	length := len(m.ciphertext)
1036	x := make([]byte, length+4)
1037	x[0] = typeClientKeyExchange
1038	x[1] = uint8(length >> 16)
1039	x[2] = uint8(length >> 8)
1040	x[3] = uint8(length)
1041	copy(x[4:], m.ciphertext)
1042
1043	m.raw = x
1044	return x
1045}
1046
1047func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
1048	m.raw = data
1049	if len(data) < 4 {
1050		return false
1051	}
1052	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1053	if l != len(data)-4 {
1054		return false
1055	}
1056	m.ciphertext = data[4:]
1057	return true
1058}
1059
// finishedMsg is a Finished handshake message. raw caches the complete
// encoding including the 4-byte header. verifyData is the entire message
// body.
type finishedMsg struct {
	raw        []byte
	verifyData []byte
}
1064
1065func (m *finishedMsg) equal(i interface{}) bool {
1066	m1, ok := i.(*finishedMsg)
1067	if !ok {
1068		return false
1069	}
1070
1071	return bytes.Equal(m.raw, m1.raw) &&
1072		bytes.Equal(m.verifyData, m1.verifyData)
1073}
1074
1075func (m *finishedMsg) marshal() (x []byte) {
1076	if m.raw != nil {
1077		return m.raw
1078	}
1079
1080	x = make([]byte, 4+len(m.verifyData))
1081	x[0] = typeFinished
1082	x[3] = byte(len(m.verifyData))
1083	copy(x[4:], m.verifyData)
1084	m.raw = x
1085	return
1086}
1087
1088func (m *finishedMsg) unmarshal(data []byte) bool {
1089	m.raw = data
1090	if len(data) < 4 {
1091		return false
1092	}
1093	m.verifyData = data[4:]
1094	return true
1095}
1096
// nextProtoMsg is a NextProtocol handshake message. raw caches the complete
// encoding including the 4-byte header. proto is the protocol name; marshal
// truncates it to 255 bytes and pads the body to a multiple of 32 bytes.
type nextProtoMsg struct {
	raw   []byte
	proto string
}
1101
1102func (m *nextProtoMsg) equal(i interface{}) bool {
1103	m1, ok := i.(*nextProtoMsg)
1104	if !ok {
1105		return false
1106	}
1107
1108	return bytes.Equal(m.raw, m1.raw) &&
1109		m.proto == m1.proto
1110}
1111
1112func (m *nextProtoMsg) marshal() []byte {
1113	if m.raw != nil {
1114		return m.raw
1115	}
1116	l := len(m.proto)
1117	if l > 255 {
1118		l = 255
1119	}
1120
1121	padding := 32 - (l+2)%32
1122	length := l + padding + 2
1123	x := make([]byte, length+4)
1124	x[0] = typeNextProtocol
1125	x[1] = uint8(length >> 16)
1126	x[2] = uint8(length >> 8)
1127	x[3] = uint8(length)
1128
1129	y := x[4:]
1130	y[0] = byte(l)
1131	copy(y[1:], []byte(m.proto[0:l]))
1132	y = y[1+l:]
1133	y[0] = byte(padding)
1134
1135	m.raw = x
1136
1137	return x
1138}
1139
1140func (m *nextProtoMsg) unmarshal(data []byte) bool {
1141	m.raw = data
1142
1143	if len(data) < 5 {
1144		return false
1145	}
1146	data = data[4:]
1147	protoLen := int(data[0])
1148	data = data[1:]
1149	if len(data) < protoLen {
1150		return false
1151	}
1152	m.proto = string(data[0:protoLen])
1153	data = data[protoLen:]
1154
1155	if len(data) < 1 {
1156		return false
1157	}
1158	paddingLen := int(data[0])
1159	data = data[1:]
1160	if len(data) != paddingLen {
1161		return false
1162	}
1163
1164	return true
1165}
1166
// certificateRequestMsg is a CertificateRequest handshake message. raw caches
// the complete encoding including the 4-byte header.
type certificateRequestMsg struct {
	raw []byte
	// hasSignatureAndHash indicates whether this message includes a list
	// of signature and hash functions. This change was introduced with TLS
	// 1.2. unmarshal reads this field but never sets it, so callers must set
	// it before parsing.
	hasSignatureAndHash bool

	certificateTypes       []byte
	signatureAndHashes     []signatureAndHash
	certificateAuthorities [][]byte
}
1178
1179func (m *certificateRequestMsg) equal(i interface{}) bool {
1180	m1, ok := i.(*certificateRequestMsg)
1181	if !ok {
1182		return false
1183	}
1184
1185	return bytes.Equal(m.raw, m1.raw) &&
1186		bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
1187		eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
1188		eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes)
1189}
1190
1191func (m *certificateRequestMsg) marshal() (x []byte) {
1192	if m.raw != nil {
1193		return m.raw
1194	}
1195
1196	// See http://tools.ietf.org/html/rfc4346#section-7.4.4
1197	length := 1 + len(m.certificateTypes) + 2
1198	casLength := 0
1199	for _, ca := range m.certificateAuthorities {
1200		casLength += 2 + len(ca)
1201	}
1202	length += casLength
1203
1204	if m.hasSignatureAndHash {
1205		length += 2 + 2*len(m.signatureAndHashes)
1206	}
1207
1208	x = make([]byte, 4+length)
1209	x[0] = typeCertificateRequest
1210	x[1] = uint8(length >> 16)
1211	x[2] = uint8(length >> 8)
1212	x[3] = uint8(length)
1213
1214	x[4] = uint8(len(m.certificateTypes))
1215
1216	copy(x[5:], m.certificateTypes)
1217	y := x[5+len(m.certificateTypes):]
1218
1219	if m.hasSignatureAndHash {
1220		n := len(m.signatureAndHashes) * 2
1221		y[0] = uint8(n >> 8)
1222		y[1] = uint8(n)
1223		y = y[2:]
1224		for _, sigAndHash := range m.signatureAndHashes {
1225			y[0] = sigAndHash.hash
1226			y[1] = sigAndHash.signature
1227			y = y[2:]
1228		}
1229	}
1230
1231	y[0] = uint8(casLength >> 8)
1232	y[1] = uint8(casLength)
1233	y = y[2:]
1234	for _, ca := range m.certificateAuthorities {
1235		y[0] = uint8(len(ca) >> 8)
1236		y[1] = uint8(len(ca))
1237		y = y[2:]
1238		copy(y, ca)
1239		y = y[len(ca):]
1240	}
1241
1242	m.raw = x
1243	return
1244}
1245
// unmarshal parses a CertificateRequest handshake message from data into m
// and reports whether data was well formed: consistent with every declared
// length and with no trailing bytes.
//
// m.hasSignatureAndHash is read but never set here, so the caller must
// configure it beforehand. It decides whether a signature-and-hash list is
// expected after the certificate types (presumably true for TLS 1.2 —
// confirm at caller). When it is false, m.signatureAndHashes is left
// untouched.
//
// m.raw is set to data even when parsing fails. Fields parsed before the
// point of failure are left populated.
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
	m.raw = data

	// Need the 4-byte handshake header plus the certificate-type count byte.
	if len(data) < 5 {
		return false
	}

	// Bytes 1-3 are a 24-bit big-endian body length. It must describe
	// exactly the bytes after the header.
	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
	if uint32(len(data))-4 != length {
		return false
	}

	// Certificate types: a 1-byte count followed by that many type bytes.
	// The list must be non-empty, and at least one byte must follow it.
	numCertTypes := int(data[4])
	data = data[5:]
	if numCertTypes == 0 || len(data) <= numCertTypes {
		return false
	}

	// The bounds check above already ensures copy fills the whole slice,
	// so this comparison is defensive.
	m.certificateTypes = make([]byte, numCertTypes)
	if copy(m.certificateTypes, data) != numCertTypes {
		return false
	}

	data = data[numCertTypes:]

	// Optional list of (hash, signature) byte pairs behind a 2-byte
	// big-endian length. The length must be even because each entry is
	// two bytes.
	if m.hasSignatureAndHash {
		if len(data) < 2 {
			return false
		}
		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
		data = data[2:]
		if sigAndHashLen&1 != 0 {
			return false
		}
		if len(data) < int(sigAndHashLen) {
			return false
		}
		numSigAndHash := sigAndHashLen / 2
		m.signatureAndHashes = make([]signatureAndHash, numSigAndHash)
		for i := range m.signatureAndHashes {
			m.signatureAndHashes[i].hash = data[0]
			m.signatureAndHashes[i].signature = data[1]
			data = data[2:]
		}
	}

	// Certificate authorities: a 2-byte big-endian total length, followed
	// by that many bytes of entries.
	if len(data) < 2 {
		return false
	}
	casLength := uint16(data[0])<<8 | uint16(data[1])
	data = data[2:]
	if len(data) < int(casLength) {
		return false
	}
	// Copy the CA block so the parsed entries do not share memory with data.
	cas := make([]byte, casLength)
	copy(cas, data)
	data = data[casLength:]

	// Reset before appending so a reused message does not keep old entries.
	// Each entry is a 2-byte big-endian length followed by that many bytes.
	// NOTE(review): entries are subslices of the single cas buffer and keep
	// its remaining capacity, so appending to one would overwrite the next.
	m.certificateAuthorities = nil
	for len(cas) > 0 {
		if len(cas) < 2 {
			return false
		}
		caLen := uint16(cas[0])<<8 | uint16(cas[1])
		cas = cas[2:]

		if len(cas) < int(caLen) {
			return false
		}

		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
		cas = cas[caLen:]
	}
	// Reject trailing bytes after the CA list.
	if len(data) > 0 {
		return false
	}

	return true
}
1325
// certificateVerifyMsg holds a CertificateVerify handshake message.
type certificateVerifyMsg struct {
	// raw caches the encoded message: marshal returns it when non-nil and
	// unmarshal stores its input here.
	raw                 []byte
	// hasSignatureAndHash selects whether the two algorithm bytes in
	// signatureAndHash precede the signature on the wire. unmarshal reads
	// it as an input (presumably set for TLS 1.2 — confirm at caller).
	hasSignatureAndHash bool
	signatureAndHash    signatureAndHash
	// signature is encoded behind a 2-byte big-endian length prefix.
	signature           []byte
}
1332
1333func (m *certificateVerifyMsg) equal(i interface{}) bool {
1334	m1, ok := i.(*certificateVerifyMsg)
1335	if !ok {
1336		return false
1337	}
1338
1339	return bytes.Equal(m.raw, m1.raw) &&
1340		m.hasSignatureAndHash == m1.hasSignatureAndHash &&
1341		m.signatureAndHash.hash == m1.signatureAndHash.hash &&
1342		m.signatureAndHash.signature == m1.signatureAndHash.signature &&
1343		bytes.Equal(m.signature, m1.signature)
1344}
1345
1346func (m *certificateVerifyMsg) marshal() (x []byte) {
1347	if m.raw != nil {
1348		return m.raw
1349	}
1350
1351	// See http://tools.ietf.org/html/rfc4346#section-7.4.8
1352	siglength := len(m.signature)
1353	length := 2 + siglength
1354	if m.hasSignatureAndHash {
1355		length += 2
1356	}
1357	x = make([]byte, 4+length)
1358	x[0] = typeCertificateVerify
1359	x[1] = uint8(length >> 16)
1360	x[2] = uint8(length >> 8)
1361	x[3] = uint8(length)
1362	y := x[4:]
1363	if m.hasSignatureAndHash {
1364		y[0] = m.signatureAndHash.hash
1365		y[1] = m.signatureAndHash.signature
1366		y = y[2:]
1367	}
1368	y[0] = uint8(siglength >> 8)
1369	y[1] = uint8(siglength)
1370	copy(y[2:], m.signature)
1371
1372	m.raw = x
1373
1374	return
1375}
1376
1377func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
1378	m.raw = data
1379
1380	if len(data) < 6 {
1381		return false
1382	}
1383
1384	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1385	if uint32(len(data))-4 != length {
1386		return false
1387	}
1388
1389	data = data[4:]
1390	if m.hasSignatureAndHash {
1391		m.signatureAndHash.hash = data[0]
1392		m.signatureAndHash.signature = data[1]
1393		data = data[2:]
1394	}
1395
1396	if len(data) < 2 {
1397		return false
1398	}
1399	siglength := int(data[0])<<8 + int(data[1])
1400	data = data[2:]
1401	if len(data) != siglength {
1402		return false
1403	}
1404
1405	m.signature = data
1406
1407	return true
1408}
1409
// newSessionTicketMsg holds a NewSessionTicket handshake message.
type newSessionTicketMsg struct {
	// raw caches the encoded message: marshal returns it when non-nil and
	// unmarshal stores its input here.
	raw    []byte
	// ticket is encoded behind a 2-byte big-endian length prefix.
	ticket []byte
}
1414
1415func (m *newSessionTicketMsg) equal(i interface{}) bool {
1416	m1, ok := i.(*newSessionTicketMsg)
1417	if !ok {
1418		return false
1419	}
1420
1421	return bytes.Equal(m.raw, m1.raw) &&
1422		bytes.Equal(m.ticket, m1.ticket)
1423}
1424
1425func (m *newSessionTicketMsg) marshal() (x []byte) {
1426	if m.raw != nil {
1427		return m.raw
1428	}
1429
1430	// See http://tools.ietf.org/html/rfc5077#section-3.3
1431	ticketLen := len(m.ticket)
1432	length := 2 + 4 + ticketLen
1433	x = make([]byte, 4+length)
1434	x[0] = typeNewSessionTicket
1435	x[1] = uint8(length >> 16)
1436	x[2] = uint8(length >> 8)
1437	x[3] = uint8(length)
1438	x[8] = uint8(ticketLen >> 8)
1439	x[9] = uint8(ticketLen)
1440	copy(x[10:], m.ticket)
1441
1442	m.raw = x
1443
1444	return
1445}
1446
1447func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
1448	m.raw = data
1449
1450	if len(data) < 10 {
1451		return false
1452	}
1453
1454	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1455	if uint32(len(data))-4 != length {
1456		return false
1457	}
1458
1459	ticketLen := int(data[8])<<8 + int(data[9])
1460	if len(data)-10 != ticketLen {
1461		return false
1462	}
1463
1464	m.ticket = data[10:]
1465
1466	return true
1467}
1468
// eqUint16s reports whether x and y have the same length and hold equal
// values at every index. A nil slice equals an empty one.
func eqUint16s(x, y []uint16) bool {
	same := len(x) == len(y)
	for i := 0; same && i < len(x); i++ {
		same = x[i] == y[i]
	}
	return same
}
1480
1481func eqCurveIDs(x, y []CurveID) bool {
1482	if len(x) != len(y) {
1483		return false
1484	}
1485	for i, v := range x {
1486		if y[i] != v {
1487			return false
1488		}
1489	}
1490	return true
1491}
1492
// eqStrings reports whether x and y contain identical strings in the same
// order. A nil slice equals an empty one.
func eqStrings(x, y []string) bool {
	if len(x) != len(y) {
		return false
	}
	for i := range x {
		if x[i] != y[i] {
			return false
		}
	}
	return true
}
1504
// eqByteSlices reports whether x and y have the same length and each pair
// of corresponding elements is equal according to bytes.Equal. That means
// a nil element equals an empty one.
func eqByteSlices(x, y [][]byte) bool {
	if len(x) != len(y) {
		return false
	}
	for i, b := range y {
		if !bytes.Equal(x[i], b) {
			return false
		}
	}
	return true
}
1516
1517func eqSignatureAndHashes(x, y []signatureAndHash) bool {
1518	if len(x) != len(y) {
1519		return false
1520	}
1521	for i, v := range x {
1522		v2 := y[i]
1523		if v.hash != v2.hash || v.signature != v2.signature {
1524			return false
1525		}
1526	}
1527	return true
1528}
1529