1// Copyright 2009 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"fmt"
9	"internal/x/crypto/cryptobyte"
10	"strings"
11)
12
13// The marshalingFunction type is an adapter to allow the use of ordinary
14// functions as cryptobyte.MarshalingValue.
15type marshalingFunction func(b *cryptobyte.Builder) error
16
17func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
18	return f(b)
19}
20
21// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
22// the length of the sequence is not the value specified, it produces an error.
23func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
24	b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
25		if len(v) != n {
26			return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
27		}
28		b.AddBytes(v)
29		return nil
30	}))
31}
32
33// addUint64 appends a big-endian, 64-bit value to the cryptobyte.Builder.
34func addUint64(b *cryptobyte.Builder, v uint64) {
35	b.AddUint32(uint32(v >> 32))
36	b.AddUint32(uint32(v))
37}
38
39// readUint64 decodes a big-endian, 64-bit value into out and advances over it.
40// It reports whether the read was successful.
41func readUint64(s *cryptobyte.String, out *uint64) bool {
42	var hi, lo uint32
43	if !s.ReadUint32(&hi) || !s.ReadUint32(&lo) {
44		return false
45	}
46	*out = uint64(hi)<<32 | uint64(lo)
47	return true
48}
49
50// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
51// []byte instead of a cryptobyte.String.
52func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
53	return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
54}
55
56// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
57// []byte instead of a cryptobyte.String.
58func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
59	return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
60}
61
62// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
63// []byte instead of a cryptobyte.String.
64func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
65	return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
66}
67
// clientHelloMsg holds the fields of a ClientHello handshake message.
// Most boolean and slice fields map onto a single extension: marshal emits
// the extension only when its field is true or non-empty, and unmarshal sets
// the field when it encounters the extension.
type clientHelloMsg struct {
	// raw caches the wire encoding. marshal returns it unchanged when it is
	// non-nil; unmarshal stores its input here.
	raw                              []byte
	vers                             uint16
	// random must be exactly 32 bytes for marshal to succeed.
	random                           []byte
	sessionId                        []byte
	cipherSuites                     []uint16
	compressionMethods               []uint8
	nextProtoNeg                     bool
	serverName                       string
	ocspStapling                     bool
	supportedCurves                  []CurveID
	supportedPoints                  []uint8
	// ticketSupported controls whether the session_ticket extension is
	// present; sessionTicket is its (possibly empty) payload.
	ticketSupported                  bool
	sessionTicket                    []uint8
	supportedSignatureAlgorithms     []SignatureScheme
	supportedSignatureAlgorithmsCert []SignatureScheme
	// secureRenegotiationSupported is set by unmarshal when either the
	// renegotiation SCSV cipher suite or the renegotiation_info extension is
	// seen; marshal emits renegotiation_info carrying secureRenegotiation.
	secureRenegotiationSupported     bool
	secureRenegotiation              []byte
	alpnProtocols                    []string
	scts                             bool
	supportedVersions                []uint16
	cookie                           []byte
	keyShares                        []keyShare
	earlyData                        bool
	pskModes                         []uint8
	// pskIdentities and pskBinders make up the pre_shared_key extension,
	// which marshal always places last. Presumably one binder per identity
	// (RFC 8446, Section 4.2.11); this type does not enforce that.
	pskIdentities                    []pskIdentity
	pskBinders                       [][]byte
}
96
97func (m *clientHelloMsg) marshal() []byte {
98	if m.raw != nil {
99		return m.raw
100	}
101
102	var b cryptobyte.Builder
103	b.AddUint8(typeClientHello)
104	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
105		b.AddUint16(m.vers)
106		addBytesWithLength(b, m.random, 32)
107		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
108			b.AddBytes(m.sessionId)
109		})
110		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
111			for _, suite := range m.cipherSuites {
112				b.AddUint16(suite)
113			}
114		})
115		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
116			b.AddBytes(m.compressionMethods)
117		})
118
119		// If extensions aren't present, omit them.
120		var extensionsPresent bool
121		bWithoutExtensions := *b
122
123		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
124			if m.nextProtoNeg {
125				// draft-agl-tls-nextprotoneg-04
126				b.AddUint16(extensionNextProtoNeg)
127				b.AddUint16(0) // empty extension_data
128			}
129			if len(m.serverName) > 0 {
130				// RFC 6066, Section 3
131				b.AddUint16(extensionServerName)
132				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
133					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
134						b.AddUint8(0) // name_type = host_name
135						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
136							b.AddBytes([]byte(m.serverName))
137						})
138					})
139				})
140			}
141			if m.ocspStapling {
142				// RFC 4366, Section 3.6
143				b.AddUint16(extensionStatusRequest)
144				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
145					b.AddUint8(1)  // status_type = ocsp
146					b.AddUint16(0) // empty responder_id_list
147					b.AddUint16(0) // empty request_extensions
148				})
149			}
150			if len(m.supportedCurves) > 0 {
151				// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
152				b.AddUint16(extensionSupportedCurves)
153				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
154					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
155						for _, curve := range m.supportedCurves {
156							b.AddUint16(uint16(curve))
157						}
158					})
159				})
160			}
161			if len(m.supportedPoints) > 0 {
162				// RFC 4492, Section 5.1.2
163				b.AddUint16(extensionSupportedPoints)
164				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
165					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
166						b.AddBytes(m.supportedPoints)
167					})
168				})
169			}
170			if m.ticketSupported {
171				// RFC 5077, Section 3.2
172				b.AddUint16(extensionSessionTicket)
173				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
174					b.AddBytes(m.sessionTicket)
175				})
176			}
177			if len(m.supportedSignatureAlgorithms) > 0 {
178				// RFC 5246, Section 7.4.1.4.1
179				b.AddUint16(extensionSignatureAlgorithms)
180				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
181					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
182						for _, sigAlgo := range m.supportedSignatureAlgorithms {
183							b.AddUint16(uint16(sigAlgo))
184						}
185					})
186				})
187			}
188			if len(m.supportedSignatureAlgorithmsCert) > 0 {
189				// RFC 8446, Section 4.2.3
190				b.AddUint16(extensionSignatureAlgorithmsCert)
191				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
192					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
193						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
194							b.AddUint16(uint16(sigAlgo))
195						}
196					})
197				})
198			}
199			if m.secureRenegotiationSupported {
200				// RFC 5746, Section 3.2
201				b.AddUint16(extensionRenegotiationInfo)
202				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
203					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
204						b.AddBytes(m.secureRenegotiation)
205					})
206				})
207			}
208			if len(m.alpnProtocols) > 0 {
209				// RFC 7301, Section 3.1
210				b.AddUint16(extensionALPN)
211				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
212					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
213						for _, proto := range m.alpnProtocols {
214							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
215								b.AddBytes([]byte(proto))
216							})
217						}
218					})
219				})
220			}
221			if m.scts {
222				// RFC 6962, Section 3.3.1
223				b.AddUint16(extensionSCT)
224				b.AddUint16(0) // empty extension_data
225			}
226			if len(m.supportedVersions) > 0 {
227				// RFC 8446, Section 4.2.1
228				b.AddUint16(extensionSupportedVersions)
229				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
230					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
231						for _, vers := range m.supportedVersions {
232							b.AddUint16(vers)
233						}
234					})
235				})
236			}
237			if len(m.cookie) > 0 {
238				// RFC 8446, Section 4.2.2
239				b.AddUint16(extensionCookie)
240				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
241					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
242						b.AddBytes(m.cookie)
243					})
244				})
245			}
246			if len(m.keyShares) > 0 {
247				// RFC 8446, Section 4.2.8
248				b.AddUint16(extensionKeyShare)
249				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
250					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
251						for _, ks := range m.keyShares {
252							b.AddUint16(uint16(ks.group))
253							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
254								b.AddBytes(ks.data)
255							})
256						}
257					})
258				})
259			}
260			if m.earlyData {
261				// RFC 8446, Section 4.2.10
262				b.AddUint16(extensionEarlyData)
263				b.AddUint16(0) // empty extension_data
264			}
265			if len(m.pskModes) > 0 {
266				// RFC 8446, Section 4.2.9
267				b.AddUint16(extensionPSKModes)
268				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
269					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
270						b.AddBytes(m.pskModes)
271					})
272				})
273			}
274			if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
275				// RFC 8446, Section 4.2.11
276				b.AddUint16(extensionPreSharedKey)
277				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
278					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
279						for _, psk := range m.pskIdentities {
280							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
281								b.AddBytes(psk.label)
282							})
283							b.AddUint32(psk.obfuscatedTicketAge)
284						}
285					})
286					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
287						for _, binder := range m.pskBinders {
288							b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
289								b.AddBytes(binder)
290							})
291						}
292					})
293				})
294			}
295
296			extensionsPresent = len(b.BytesOrPanic()) > 2
297		})
298
299		if !extensionsPresent {
300			*b = bWithoutExtensions
301		}
302	})
303
304	m.raw = b.BytesOrPanic()
305	return m.raw
306}
307
308// marshalWithoutBinders returns the ClientHello through the
309// PreSharedKeyExtension.identities field, according to RFC 8446, Section
310// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
311func (m *clientHelloMsg) marshalWithoutBinders() []byte {
312	bindersLen := 2 // uint16 length prefix
313	for _, binder := range m.pskBinders {
314		bindersLen += 1 // uint8 length prefix
315		bindersLen += len(binder)
316	}
317
318	fullMessage := m.marshal()
319	return fullMessage[:len(fullMessage)-bindersLen]
320}
321
322// updateBinders updates the m.pskBinders field, if necessary updating the
323// cached marshalled representation. The supplied binders must have the same
324// length as the current m.pskBinders.
325func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) {
326	if len(pskBinders) != len(m.pskBinders) {
327		panic("tls: internal error: pskBinders length mismatch")
328	}
329	for i := range m.pskBinders {
330		if len(pskBinders[i]) != len(m.pskBinders[i]) {
331			panic("tls: internal error: pskBinders length mismatch")
332		}
333	}
334	m.pskBinders = pskBinders
335	if m.raw != nil {
336		lenWithoutBinders := len(m.marshalWithoutBinders())
337		// TODO(filippo): replace with NewFixedBuilder once CL 148882 is imported.
338		b := cryptobyte.NewBuilder(m.raw[:lenWithoutBinders])
339		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
340			for _, binder := range m.pskBinders {
341				b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
342					b.AddBytes(binder)
343				})
344			}
345		})
346		if len(b.BytesOrPanic()) != len(m.raw) {
347			panic("tls: internal error: failed to update binders")
348		}
349	}
350}
351
352func (m *clientHelloMsg) unmarshal(data []byte) bool {
353	*m = clientHelloMsg{raw: data}
354	s := cryptobyte.String(data)
355
356	if !s.Skip(4) || // message type and uint24 length field
357		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
358		!readUint8LengthPrefixed(&s, &m.sessionId) {
359		return false
360	}
361
362	var cipherSuites cryptobyte.String
363	if !s.ReadUint16LengthPrefixed(&cipherSuites) {
364		return false
365	}
366	m.cipherSuites = []uint16{}
367	m.secureRenegotiationSupported = false
368	for !cipherSuites.Empty() {
369		var suite uint16
370		if !cipherSuites.ReadUint16(&suite) {
371			return false
372		}
373		if suite == scsvRenegotiation {
374			m.secureRenegotiationSupported = true
375		}
376		m.cipherSuites = append(m.cipherSuites, suite)
377	}
378
379	if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
380		return false
381	}
382
383	if s.Empty() {
384		// ClientHello is optionally followed by extension data
385		return true
386	}
387
388	var extensions cryptobyte.String
389	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
390		return false
391	}
392
393	for !extensions.Empty() {
394		var extension uint16
395		var extData cryptobyte.String
396		if !extensions.ReadUint16(&extension) ||
397			!extensions.ReadUint16LengthPrefixed(&extData) {
398			return false
399		}
400
401		switch extension {
402		case extensionServerName:
403			// RFC 6066, Section 3
404			var nameList cryptobyte.String
405			if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
406				return false
407			}
408			for !nameList.Empty() {
409				var nameType uint8
410				var serverName cryptobyte.String
411				if !nameList.ReadUint8(&nameType) ||
412					!nameList.ReadUint16LengthPrefixed(&serverName) ||
413					serverName.Empty() {
414					return false
415				}
416				if nameType != 0 {
417					continue
418				}
419				if len(m.serverName) != 0 {
420					// Multiple names of the same name_type are prohibited.
421					return false
422				}
423				m.serverName = string(serverName)
424				// An SNI value may not include a trailing dot.
425				if strings.HasSuffix(m.serverName, ".") {
426					return false
427				}
428			}
429		case extensionNextProtoNeg:
430			// draft-agl-tls-nextprotoneg-04
431			m.nextProtoNeg = true
432		case extensionStatusRequest:
433			// RFC 4366, Section 3.6
434			var statusType uint8
435			var ignored cryptobyte.String
436			if !extData.ReadUint8(&statusType) ||
437				!extData.ReadUint16LengthPrefixed(&ignored) ||
438				!extData.ReadUint16LengthPrefixed(&ignored) {
439				return false
440			}
441			m.ocspStapling = statusType == statusTypeOCSP
442		case extensionSupportedCurves:
443			// RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
444			var curves cryptobyte.String
445			if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
446				return false
447			}
448			for !curves.Empty() {
449				var curve uint16
450				if !curves.ReadUint16(&curve) {
451					return false
452				}
453				m.supportedCurves = append(m.supportedCurves, CurveID(curve))
454			}
455		case extensionSupportedPoints:
456			// RFC 4492, Section 5.1.2
457			if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
458				len(m.supportedPoints) == 0 {
459				return false
460			}
461		case extensionSessionTicket:
462			// RFC 5077, Section 3.2
463			m.ticketSupported = true
464			extData.ReadBytes(&m.sessionTicket, len(extData))
465		case extensionSignatureAlgorithms:
466			// RFC 5246, Section 7.4.1.4.1
467			var sigAndAlgs cryptobyte.String
468			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
469				return false
470			}
471			for !sigAndAlgs.Empty() {
472				var sigAndAlg uint16
473				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
474					return false
475				}
476				m.supportedSignatureAlgorithms = append(
477					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
478			}
479		case extensionSignatureAlgorithmsCert:
480			// RFC 8446, Section 4.2.3
481			var sigAndAlgs cryptobyte.String
482			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
483				return false
484			}
485			for !sigAndAlgs.Empty() {
486				var sigAndAlg uint16
487				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
488					return false
489				}
490				m.supportedSignatureAlgorithmsCert = append(
491					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
492			}
493		case extensionRenegotiationInfo:
494			// RFC 5746, Section 3.2
495			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
496				return false
497			}
498			m.secureRenegotiationSupported = true
499		case extensionALPN:
500			// RFC 7301, Section 3.1
501			var protoList cryptobyte.String
502			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
503				return false
504			}
505			for !protoList.Empty() {
506				var proto cryptobyte.String
507				if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
508					return false
509				}
510				m.alpnProtocols = append(m.alpnProtocols, string(proto))
511			}
512		case extensionSCT:
513			// RFC 6962, Section 3.3.1
514			m.scts = true
515		case extensionSupportedVersions:
516			// RFC 8446, Section 4.2.1
517			var versList cryptobyte.String
518			if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
519				return false
520			}
521			for !versList.Empty() {
522				var vers uint16
523				if !versList.ReadUint16(&vers) {
524					return false
525				}
526				m.supportedVersions = append(m.supportedVersions, vers)
527			}
528		case extensionCookie:
529			// RFC 8446, Section 4.2.2
530			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
531				len(m.cookie) == 0 {
532				return false
533			}
534		case extensionKeyShare:
535			// RFC 8446, Section 4.2.8
536			var clientShares cryptobyte.String
537			if !extData.ReadUint16LengthPrefixed(&clientShares) {
538				return false
539			}
540			for !clientShares.Empty() {
541				var ks keyShare
542				if !clientShares.ReadUint16((*uint16)(&ks.group)) ||
543					!readUint16LengthPrefixed(&clientShares, &ks.data) ||
544					len(ks.data) == 0 {
545					return false
546				}
547				m.keyShares = append(m.keyShares, ks)
548			}
549		case extensionEarlyData:
550			// RFC 8446, Section 4.2.10
551			m.earlyData = true
552		case extensionPSKModes:
553			// RFC 8446, Section 4.2.9
554			if !readUint8LengthPrefixed(&extData, &m.pskModes) {
555				return false
556			}
557		case extensionPreSharedKey:
558			// RFC 8446, Section 4.2.11
559			if !extensions.Empty() {
560				return false // pre_shared_key must be the last extension
561			}
562			var identities cryptobyte.String
563			if !extData.ReadUint16LengthPrefixed(&identities) || identities.Empty() {
564				return false
565			}
566			for !identities.Empty() {
567				var psk pskIdentity
568				if !readUint16LengthPrefixed(&identities, &psk.label) ||
569					!identities.ReadUint32(&psk.obfuscatedTicketAge) ||
570					len(psk.label) == 0 {
571					return false
572				}
573				m.pskIdentities = append(m.pskIdentities, psk)
574			}
575			var binders cryptobyte.String
576			if !extData.ReadUint16LengthPrefixed(&binders) || binders.Empty() {
577				return false
578			}
579			for !binders.Empty() {
580				var binder []byte
581				if !readUint8LengthPrefixed(&binders, &binder) ||
582					len(binder) == 0 {
583					return false
584				}
585				m.pskBinders = append(m.pskBinders, binder)
586			}
587		default:
588			// Ignore unknown extensions.
589			continue
590		}
591
592		if !extData.Empty() {
593			return false
594		}
595	}
596
597	return true
598}
599
// serverHelloMsg holds the fields of a ServerHello handshake message. The
// same type also carries HelloRetryRequest-only extensions (see the trailing
// fields). marshal emits an extension only when its field is set or
// non-zero, and unmarshal sets the field when it encounters the extension.
type serverHelloMsg struct {
	// raw caches the wire encoding. marshal returns it unchanged when it is
	// non-nil; unmarshal stores its input here.
	raw                          []byte
	vers                         uint16
	// random must be exactly 32 bytes for marshal to succeed.
	random                       []byte
	sessionId                    []byte
	cipherSuite                  uint16
	compressionMethod            uint8
	// nextProtoNeg controls the NPN extension; nextProtos is the list of
	// protocol names carried in it.
	nextProtoNeg                 bool
	nextProtos                   []string
	ocspStapling                 bool
	ticketSupported              bool
	secureRenegotiationSupported bool
	secureRenegotiation          []byte
	// alpnProtocol is the single selected protocol; "" means no ALPN
	// extension.
	alpnProtocol                 string
	scts                         [][]byte
	// supportedVersion is the version from supported_versions; 0 means the
	// extension is absent.
	supportedVersion             uint16
	// serverShare is the ServerHello form of key_share; a zero group means
	// the extension is absent.
	serverShare                  keyShare
	// selectedIdentityPresent controls the pre_shared_key extension, which
	// carries selectedIdentity.
	selectedIdentityPresent      bool
	selectedIdentity             uint16

	// HelloRetryRequest extensions
	cookie        []byte
	// selectedGroup is the HelloRetryRequest form of key_share (a bare
	// group); 0 means absent.
	selectedGroup CurveID
}
624
625func (m *serverHelloMsg) marshal() []byte {
626	if m.raw != nil {
627		return m.raw
628	}
629
630	var b cryptobyte.Builder
631	b.AddUint8(typeServerHello)
632	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
633		b.AddUint16(m.vers)
634		addBytesWithLength(b, m.random, 32)
635		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
636			b.AddBytes(m.sessionId)
637		})
638		b.AddUint16(m.cipherSuite)
639		b.AddUint8(m.compressionMethod)
640
641		// If extensions aren't present, omit them.
642		var extensionsPresent bool
643		bWithoutExtensions := *b
644
645		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
646			if m.nextProtoNeg {
647				b.AddUint16(extensionNextProtoNeg)
648				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
649					for _, proto := range m.nextProtos {
650						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
651							b.AddBytes([]byte(proto))
652						})
653					}
654				})
655			}
656			if m.ocspStapling {
657				b.AddUint16(extensionStatusRequest)
658				b.AddUint16(0) // empty extension_data
659			}
660			if m.ticketSupported {
661				b.AddUint16(extensionSessionTicket)
662				b.AddUint16(0) // empty extension_data
663			}
664			if m.secureRenegotiationSupported {
665				b.AddUint16(extensionRenegotiationInfo)
666				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
667					b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
668						b.AddBytes(m.secureRenegotiation)
669					})
670				})
671			}
672			if len(m.alpnProtocol) > 0 {
673				b.AddUint16(extensionALPN)
674				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
675					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
676						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
677							b.AddBytes([]byte(m.alpnProtocol))
678						})
679					})
680				})
681			}
682			if len(m.scts) > 0 {
683				b.AddUint16(extensionSCT)
684				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
685					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
686						for _, sct := range m.scts {
687							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
688								b.AddBytes(sct)
689							})
690						}
691					})
692				})
693			}
694			if m.supportedVersion != 0 {
695				b.AddUint16(extensionSupportedVersions)
696				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
697					b.AddUint16(m.supportedVersion)
698				})
699			}
700			if m.serverShare.group != 0 {
701				b.AddUint16(extensionKeyShare)
702				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
703					b.AddUint16(uint16(m.serverShare.group))
704					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
705						b.AddBytes(m.serverShare.data)
706					})
707				})
708			}
709			if m.selectedIdentityPresent {
710				b.AddUint16(extensionPreSharedKey)
711				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
712					b.AddUint16(m.selectedIdentity)
713				})
714			}
715
716			if len(m.cookie) > 0 {
717				b.AddUint16(extensionCookie)
718				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
719					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
720						b.AddBytes(m.cookie)
721					})
722				})
723			}
724			if m.selectedGroup != 0 {
725				b.AddUint16(extensionKeyShare)
726				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
727					b.AddUint16(uint16(m.selectedGroup))
728				})
729			}
730
731			extensionsPresent = len(b.BytesOrPanic()) > 2
732		})
733
734		if !extensionsPresent {
735			*b = bWithoutExtensions
736		}
737	})
738
739	m.raw = b.BytesOrPanic()
740	return m.raw
741}
742
743func (m *serverHelloMsg) unmarshal(data []byte) bool {
744	*m = serverHelloMsg{raw: data}
745	s := cryptobyte.String(data)
746
747	if !s.Skip(4) || // message type and uint24 length field
748		!s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
749		!readUint8LengthPrefixed(&s, &m.sessionId) ||
750		!s.ReadUint16(&m.cipherSuite) ||
751		!s.ReadUint8(&m.compressionMethod) {
752		return false
753	}
754
755	if s.Empty() {
756		// ServerHello is optionally followed by extension data
757		return true
758	}
759
760	var extensions cryptobyte.String
761	if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
762		return false
763	}
764
765	for !extensions.Empty() {
766		var extension uint16
767		var extData cryptobyte.String
768		if !extensions.ReadUint16(&extension) ||
769			!extensions.ReadUint16LengthPrefixed(&extData) {
770			return false
771		}
772
773		switch extension {
774		case extensionNextProtoNeg:
775			m.nextProtoNeg = true
776			for !extData.Empty() {
777				var proto cryptobyte.String
778				if !extData.ReadUint8LengthPrefixed(&proto) ||
779					proto.Empty() {
780					return false
781				}
782				m.nextProtos = append(m.nextProtos, string(proto))
783			}
784		case extensionStatusRequest:
785			m.ocspStapling = true
786		case extensionSessionTicket:
787			m.ticketSupported = true
788		case extensionRenegotiationInfo:
789			if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
790				return false
791			}
792			m.secureRenegotiationSupported = true
793		case extensionALPN:
794			var protoList cryptobyte.String
795			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
796				return false
797			}
798			var proto cryptobyte.String
799			if !protoList.ReadUint8LengthPrefixed(&proto) ||
800				proto.Empty() || !protoList.Empty() {
801				return false
802			}
803			m.alpnProtocol = string(proto)
804		case extensionSCT:
805			var sctList cryptobyte.String
806			if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
807				return false
808			}
809			for !sctList.Empty() {
810				var sct []byte
811				if !readUint16LengthPrefixed(&sctList, &sct) ||
812					len(sct) == 0 {
813					return false
814				}
815				m.scts = append(m.scts, sct)
816			}
817		case extensionSupportedVersions:
818			if !extData.ReadUint16(&m.supportedVersion) {
819				return false
820			}
821		case extensionCookie:
822			if !readUint16LengthPrefixed(&extData, &m.cookie) ||
823				len(m.cookie) == 0 {
824				return false
825			}
826		case extensionKeyShare:
827			// This extension has different formats in SH and HRR, accept either
828			// and let the handshake logic decide. See RFC 8446, Section 4.2.8.
829			if len(extData) == 2 {
830				if !extData.ReadUint16((*uint16)(&m.selectedGroup)) {
831					return false
832				}
833			} else {
834				if !extData.ReadUint16((*uint16)(&m.serverShare.group)) ||
835					!readUint16LengthPrefixed(&extData, &m.serverShare.data) {
836					return false
837				}
838			}
839		case extensionPreSharedKey:
840			m.selectedIdentityPresent = true
841			if !extData.ReadUint16(&m.selectedIdentity) {
842				return false
843			}
844		default:
845			// Ignore unknown extensions.
846			continue
847		}
848
849		if !extData.Empty() {
850			return false
851		}
852	}
853
854	return true
855}
856
// encryptedExtensionsMsg holds an EncryptedExtensions handshake message. Of
// its extensions, only ALPN is encoded and decoded here; unmarshal skips all
// others.
type encryptedExtensionsMsg struct {
	// raw caches the wire encoding. marshal returns it unchanged when it is
	// non-nil; unmarshal stores its input here.
	raw          []byte
	// alpnProtocol is the single selected protocol; "" means the ALPN
	// extension is absent.
	alpnProtocol string
}
861
862func (m *encryptedExtensionsMsg) marshal() []byte {
863	if m.raw != nil {
864		return m.raw
865	}
866
867	var b cryptobyte.Builder
868	b.AddUint8(typeEncryptedExtensions)
869	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
870		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
871			if len(m.alpnProtocol) > 0 {
872				b.AddUint16(extensionALPN)
873				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
874					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
875						b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
876							b.AddBytes([]byte(m.alpnProtocol))
877						})
878					})
879				})
880			}
881		})
882	})
883
884	m.raw = b.BytesOrPanic()
885	return m.raw
886}
887
888func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
889	*m = encryptedExtensionsMsg{raw: data}
890	s := cryptobyte.String(data)
891
892	var extensions cryptobyte.String
893	if !s.Skip(4) || // message type and uint24 length field
894		!s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
895		return false
896	}
897
898	for !extensions.Empty() {
899		var extension uint16
900		var extData cryptobyte.String
901		if !extensions.ReadUint16(&extension) ||
902			!extensions.ReadUint16LengthPrefixed(&extData) {
903			return false
904		}
905
906		switch extension {
907		case extensionALPN:
908			var protoList cryptobyte.String
909			if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
910				return false
911			}
912			var proto cryptobyte.String
913			if !protoList.ReadUint8LengthPrefixed(&proto) ||
914				proto.Empty() || !protoList.Empty() {
915				return false
916			}
917			m.alpnProtocol = string(proto)
918		default:
919			// Ignore unknown extensions.
920			continue
921		}
922
923		if !extData.Empty() {
924			return false
925		}
926	}
927
928	return true
929}
930
// endOfEarlyDataMsg is the EndOfEarlyData handshake message, which has an
// empty body and so needs no fields.
type endOfEarlyDataMsg struct{}
932
933func (m *endOfEarlyDataMsg) marshal() []byte {
934	x := make([]byte, 4)
935	x[0] = typeEndOfEarlyData
936	return x
937}
938
939func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
940	return len(data) == 4
941}
942
// keyUpdateMsg holds a KeyUpdate handshake message.
type keyUpdateMsg struct {
	// raw caches the wire encoding. marshal returns it unchanged when it is
	// non-nil; unmarshal stores its input here.
	raw             []byte
	// updateRequested is encoded as a single byte: 1 when true, 0 when false.
	updateRequested bool
}
947
948func (m *keyUpdateMsg) marshal() []byte {
949	if m.raw != nil {
950		return m.raw
951	}
952
953	var b cryptobyte.Builder
954	b.AddUint8(typeKeyUpdate)
955	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
956		if m.updateRequested {
957			b.AddUint8(1)
958		} else {
959			b.AddUint8(0)
960		}
961	})
962
963	m.raw = b.BytesOrPanic()
964	return m.raw
965}
966
967func (m *keyUpdateMsg) unmarshal(data []byte) bool {
968	m.raw = data
969	s := cryptobyte.String(data)
970
971	var updateRequested uint8
972	if !s.Skip(4) || // message type and uint24 length field
973		!s.ReadUint8(&updateRequested) || !s.Empty() {
974		return false
975	}
976	switch updateRequested {
977	case 0:
978		m.updateRequested = false
979	case 1:
980		m.updateRequested = true
981	default:
982		return false
983	}
984	return true
985}
986
// newSessionTicketMsgTLS13 holds a TLS 1.3 NewSessionTicket handshake message.
// The fields appear on the wire in declaration order (after raw), followed by
// an extensions block.
type newSessionTicketMsgTLS13 struct {
	// raw caches the wire encoding. marshal returns it unchanged when it is
	// non-nil; unmarshal stores its input here.
	raw          []byte
	lifetime     uint32
	ageAdd       uint32
	// nonce is encoded with a one-byte length prefix.
	nonce        []byte
	// label is the ticket itself, encoded with a two-byte length prefix.
	label        []byte
	// maxEarlyData is carried in the early_data extension; 0 means the
	// extension is omitted.
	maxEarlyData uint32
}
995
996func (m *newSessionTicketMsgTLS13) marshal() []byte {
997	if m.raw != nil {
998		return m.raw
999	}
1000
1001	var b cryptobyte.Builder
1002	b.AddUint8(typeNewSessionTicket)
1003	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1004		b.AddUint32(m.lifetime)
1005		b.AddUint32(m.ageAdd)
1006		b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
1007			b.AddBytes(m.nonce)
1008		})
1009		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1010			b.AddBytes(m.label)
1011		})
1012
1013		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1014			if m.maxEarlyData > 0 {
1015				b.AddUint16(extensionEarlyData)
1016				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1017					b.AddUint32(m.maxEarlyData)
1018				})
1019			}
1020		})
1021	})
1022
1023	m.raw = b.BytesOrPanic()
1024	return m.raw
1025}
1026
1027func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
1028	*m = newSessionTicketMsgTLS13{raw: data}
1029	s := cryptobyte.String(data)
1030
1031	var extensions cryptobyte.String
1032	if !s.Skip(4) || // message type and uint24 length field
1033		!s.ReadUint32(&m.lifetime) ||
1034		!s.ReadUint32(&m.ageAdd) ||
1035		!readUint8LengthPrefixed(&s, &m.nonce) ||
1036		!readUint16LengthPrefixed(&s, &m.label) ||
1037		!s.ReadUint16LengthPrefixed(&extensions) ||
1038		!s.Empty() {
1039		return false
1040	}
1041
1042	for !extensions.Empty() {
1043		var extension uint16
1044		var extData cryptobyte.String
1045		if !extensions.ReadUint16(&extension) ||
1046			!extensions.ReadUint16LengthPrefixed(&extData) {
1047			return false
1048		}
1049
1050		switch extension {
1051		case extensionEarlyData:
1052			if !extData.ReadUint32(&m.maxEarlyData) {
1053				return false
1054			}
1055		default:
1056			// Ignore unknown extensions.
1057			continue
1058		}
1059
1060		if !extData.Empty() {
1061			return false
1062		}
1063	}
1064
1065	return true
1066}
1067
1068type certificateRequestMsgTLS13 struct {
1069	raw                              []byte
1070	ocspStapling                     bool
1071	scts                             bool
1072	supportedSignatureAlgorithms     []SignatureScheme
1073	supportedSignatureAlgorithmsCert []SignatureScheme
1074	certificateAuthorities           [][]byte
1075}
1076
1077func (m *certificateRequestMsgTLS13) marshal() []byte {
1078	if m.raw != nil {
1079		return m.raw
1080	}
1081
1082	var b cryptobyte.Builder
1083	b.AddUint8(typeCertificateRequest)
1084	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1085		// certificate_request_context (SHALL be zero length unless used for
1086		// post-handshake authentication)
1087		b.AddUint8(0)
1088
1089		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1090			if m.ocspStapling {
1091				b.AddUint16(extensionStatusRequest)
1092				b.AddUint16(0) // empty extension_data
1093			}
1094			if m.scts {
1095				// RFC 8446, Section 4.4.2.1 makes no mention of
1096				// signed_certificate_timestamp in CertificateRequest, but
1097				// "Extensions in the Certificate message from the client MUST
1098				// correspond to extensions in the CertificateRequest message
1099				// from the server." and it appears in the table in Section 4.2.
1100				b.AddUint16(extensionSCT)
1101				b.AddUint16(0) // empty extension_data
1102			}
1103			if len(m.supportedSignatureAlgorithms) > 0 {
1104				b.AddUint16(extensionSignatureAlgorithms)
1105				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1106					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1107						for _, sigAlgo := range m.supportedSignatureAlgorithms {
1108							b.AddUint16(uint16(sigAlgo))
1109						}
1110					})
1111				})
1112			}
1113			if len(m.supportedSignatureAlgorithmsCert) > 0 {
1114				b.AddUint16(extensionSignatureAlgorithmsCert)
1115				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1116					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1117						for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
1118							b.AddUint16(uint16(sigAlgo))
1119						}
1120					})
1121				})
1122			}
1123			if len(m.certificateAuthorities) > 0 {
1124				b.AddUint16(extensionCertificateAuthorities)
1125				b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1126					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1127						for _, ca := range m.certificateAuthorities {
1128							b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1129								b.AddBytes(ca)
1130							})
1131						}
1132					})
1133				})
1134			}
1135		})
1136	})
1137
1138	m.raw = b.BytesOrPanic()
1139	return m.raw
1140}
1141
1142func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
1143	*m = certificateRequestMsgTLS13{raw: data}
1144	s := cryptobyte.String(data)
1145
1146	var context, extensions cryptobyte.String
1147	if !s.Skip(4) || // message type and uint24 length field
1148		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
1149		!s.ReadUint16LengthPrefixed(&extensions) ||
1150		!s.Empty() {
1151		return false
1152	}
1153
1154	for !extensions.Empty() {
1155		var extension uint16
1156		var extData cryptobyte.String
1157		if !extensions.ReadUint16(&extension) ||
1158			!extensions.ReadUint16LengthPrefixed(&extData) {
1159			return false
1160		}
1161
1162		switch extension {
1163		case extensionStatusRequest:
1164			m.ocspStapling = true
1165		case extensionSCT:
1166			m.scts = true
1167		case extensionSignatureAlgorithms:
1168			var sigAndAlgs cryptobyte.String
1169			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
1170				return false
1171			}
1172			for !sigAndAlgs.Empty() {
1173				var sigAndAlg uint16
1174				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
1175					return false
1176				}
1177				m.supportedSignatureAlgorithms = append(
1178					m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
1179			}
1180		case extensionSignatureAlgorithmsCert:
1181			var sigAndAlgs cryptobyte.String
1182			if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
1183				return false
1184			}
1185			for !sigAndAlgs.Empty() {
1186				var sigAndAlg uint16
1187				if !sigAndAlgs.ReadUint16(&sigAndAlg) {
1188					return false
1189				}
1190				m.supportedSignatureAlgorithmsCert = append(
1191					m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg))
1192			}
1193		case extensionCertificateAuthorities:
1194			var auths cryptobyte.String
1195			if !extData.ReadUint16LengthPrefixed(&auths) || auths.Empty() {
1196				return false
1197			}
1198			for !auths.Empty() {
1199				var ca []byte
1200				if !readUint16LengthPrefixed(&auths, &ca) || len(ca) == 0 {
1201					return false
1202				}
1203				m.certificateAuthorities = append(m.certificateAuthorities, ca)
1204			}
1205		default:
1206			// Ignore unknown extensions.
1207			continue
1208		}
1209
1210		if !extData.Empty() {
1211			return false
1212		}
1213	}
1214
1215	return true
1216}
1217
1218type certificateMsg struct {
1219	raw          []byte
1220	certificates [][]byte
1221}
1222
1223func (m *certificateMsg) marshal() (x []byte) {
1224	if m.raw != nil {
1225		return m.raw
1226	}
1227
1228	var i int
1229	for _, slice := range m.certificates {
1230		i += len(slice)
1231	}
1232
1233	length := 3 + 3*len(m.certificates) + i
1234	x = make([]byte, 4+length)
1235	x[0] = typeCertificate
1236	x[1] = uint8(length >> 16)
1237	x[2] = uint8(length >> 8)
1238	x[3] = uint8(length)
1239
1240	certificateOctets := length - 3
1241	x[4] = uint8(certificateOctets >> 16)
1242	x[5] = uint8(certificateOctets >> 8)
1243	x[6] = uint8(certificateOctets)
1244
1245	y := x[7:]
1246	for _, slice := range m.certificates {
1247		y[0] = uint8(len(slice) >> 16)
1248		y[1] = uint8(len(slice) >> 8)
1249		y[2] = uint8(len(slice))
1250		copy(y[3:], slice)
1251		y = y[3+len(slice):]
1252	}
1253
1254	m.raw = x
1255	return
1256}
1257
1258func (m *certificateMsg) unmarshal(data []byte) bool {
1259	if len(data) < 7 {
1260		return false
1261	}
1262
1263	m.raw = data
1264	certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6])
1265	if uint32(len(data)) != certsLen+7 {
1266		return false
1267	}
1268
1269	numCerts := 0
1270	d := data[7:]
1271	for certsLen > 0 {
1272		if len(d) < 4 {
1273			return false
1274		}
1275		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
1276		if uint32(len(d)) < 3+certLen {
1277			return false
1278		}
1279		d = d[3+certLen:]
1280		certsLen -= 3 + certLen
1281		numCerts++
1282	}
1283
1284	m.certificates = make([][]byte, numCerts)
1285	d = data[7:]
1286	for i := 0; i < numCerts; i++ {
1287		certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2])
1288		m.certificates[i] = d[3 : 3+certLen]
1289		d = d[3+certLen:]
1290	}
1291
1292	return true
1293}
1294
1295type certificateMsgTLS13 struct {
1296	raw          []byte
1297	certificate  Certificate
1298	ocspStapling bool
1299	scts         bool
1300}
1301
1302func (m *certificateMsgTLS13) marshal() []byte {
1303	if m.raw != nil {
1304		return m.raw
1305	}
1306
1307	var b cryptobyte.Builder
1308	b.AddUint8(typeCertificate)
1309	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1310		b.AddUint8(0) // certificate_request_context
1311
1312		certificate := m.certificate
1313		if !m.ocspStapling {
1314			certificate.OCSPStaple = nil
1315		}
1316		if !m.scts {
1317			certificate.SignedCertificateTimestamps = nil
1318		}
1319		marshalCertificate(b, certificate)
1320	})
1321
1322	m.raw = b.BytesOrPanic()
1323	return m.raw
1324}
1325
1326func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
1327	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1328		for i, cert := range certificate.Certificate {
1329			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1330				b.AddBytes(cert)
1331			})
1332			b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1333				if i > 0 {
1334					// This library only supports OCSP and SCT for leaf certificates.
1335					return
1336				}
1337				if certificate.OCSPStaple != nil {
1338					b.AddUint16(extensionStatusRequest)
1339					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1340						b.AddUint8(statusTypeOCSP)
1341						b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1342							b.AddBytes(certificate.OCSPStaple)
1343						})
1344					})
1345				}
1346				if certificate.SignedCertificateTimestamps != nil {
1347					b.AddUint16(extensionSCT)
1348					b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1349						b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1350							for _, sct := range certificate.SignedCertificateTimestamps {
1351								b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1352									b.AddBytes(sct)
1353								})
1354							}
1355						})
1356					})
1357				}
1358			})
1359		}
1360	})
1361}
1362
1363func (m *certificateMsgTLS13) unmarshal(data []byte) bool {
1364	*m = certificateMsgTLS13{raw: data}
1365	s := cryptobyte.String(data)
1366
1367	var context cryptobyte.String
1368	if !s.Skip(4) || // message type and uint24 length field
1369		!s.ReadUint8LengthPrefixed(&context) || !context.Empty() ||
1370		!unmarshalCertificate(&s, &m.certificate) ||
1371		!s.Empty() {
1372		return false
1373	}
1374
1375	m.scts = m.certificate.SignedCertificateTimestamps != nil
1376	m.ocspStapling = m.certificate.OCSPStaple != nil
1377
1378	return true
1379}
1380
1381func unmarshalCertificate(s *cryptobyte.String, certificate *Certificate) bool {
1382	var certList cryptobyte.String
1383	if !s.ReadUint24LengthPrefixed(&certList) {
1384		return false
1385	}
1386	for !certList.Empty() {
1387		var cert []byte
1388		var extensions cryptobyte.String
1389		if !readUint24LengthPrefixed(&certList, &cert) ||
1390			!certList.ReadUint16LengthPrefixed(&extensions) {
1391			return false
1392		}
1393		certificate.Certificate = append(certificate.Certificate, cert)
1394		for !extensions.Empty() {
1395			var extension uint16
1396			var extData cryptobyte.String
1397			if !extensions.ReadUint16(&extension) ||
1398				!extensions.ReadUint16LengthPrefixed(&extData) {
1399				return false
1400			}
1401			if len(certificate.Certificate) > 1 {
1402				// This library only supports OCSP and SCT for leaf certificates.
1403				continue
1404			}
1405
1406			switch extension {
1407			case extensionStatusRequest:
1408				var statusType uint8
1409				if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
1410					!readUint24LengthPrefixed(&extData, &certificate.OCSPStaple) ||
1411					len(certificate.OCSPStaple) == 0 {
1412					return false
1413				}
1414			case extensionSCT:
1415				var sctList cryptobyte.String
1416				if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
1417					return false
1418				}
1419				for !sctList.Empty() {
1420					var sct []byte
1421					if !readUint16LengthPrefixed(&sctList, &sct) ||
1422						len(sct) == 0 {
1423						return false
1424					}
1425					certificate.SignedCertificateTimestamps = append(
1426						certificate.SignedCertificateTimestamps, sct)
1427				}
1428			default:
1429				// Ignore unknown extensions.
1430				continue
1431			}
1432
1433			if !extData.Empty() {
1434				return false
1435			}
1436		}
1437	}
1438	return true
1439}
1440
1441type serverKeyExchangeMsg struct {
1442	raw []byte
1443	key []byte
1444}
1445
1446func (m *serverKeyExchangeMsg) marshal() []byte {
1447	if m.raw != nil {
1448		return m.raw
1449	}
1450	length := len(m.key)
1451	x := make([]byte, length+4)
1452	x[0] = typeServerKeyExchange
1453	x[1] = uint8(length >> 16)
1454	x[2] = uint8(length >> 8)
1455	x[3] = uint8(length)
1456	copy(x[4:], m.key)
1457
1458	m.raw = x
1459	return x
1460}
1461
1462func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
1463	m.raw = data
1464	if len(data) < 4 {
1465		return false
1466	}
1467	m.key = data[4:]
1468	return true
1469}
1470
1471type certificateStatusMsg struct {
1472	raw      []byte
1473	response []byte
1474}
1475
1476func (m *certificateStatusMsg) marshal() []byte {
1477	if m.raw != nil {
1478		return m.raw
1479	}
1480
1481	var b cryptobyte.Builder
1482	b.AddUint8(typeCertificateStatus)
1483	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1484		b.AddUint8(statusTypeOCSP)
1485		b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1486			b.AddBytes(m.response)
1487		})
1488	})
1489
1490	m.raw = b.BytesOrPanic()
1491	return m.raw
1492}
1493
1494func (m *certificateStatusMsg) unmarshal(data []byte) bool {
1495	m.raw = data
1496	s := cryptobyte.String(data)
1497
1498	var statusType uint8
1499	if !s.Skip(4) || // message type and uint24 length field
1500		!s.ReadUint8(&statusType) || statusType != statusTypeOCSP ||
1501		!readUint24LengthPrefixed(&s, &m.response) ||
1502		len(m.response) == 0 || !s.Empty() {
1503		return false
1504	}
1505	return true
1506}
1507
1508type serverHelloDoneMsg struct{}
1509
1510func (m *serverHelloDoneMsg) marshal() []byte {
1511	x := make([]byte, 4)
1512	x[0] = typeServerHelloDone
1513	return x
1514}
1515
1516func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
1517	return len(data) == 4
1518}
1519
1520type clientKeyExchangeMsg struct {
1521	raw        []byte
1522	ciphertext []byte
1523}
1524
1525func (m *clientKeyExchangeMsg) marshal() []byte {
1526	if m.raw != nil {
1527		return m.raw
1528	}
1529	length := len(m.ciphertext)
1530	x := make([]byte, length+4)
1531	x[0] = typeClientKeyExchange
1532	x[1] = uint8(length >> 16)
1533	x[2] = uint8(length >> 8)
1534	x[3] = uint8(length)
1535	copy(x[4:], m.ciphertext)
1536
1537	m.raw = x
1538	return x
1539}
1540
1541func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
1542	m.raw = data
1543	if len(data) < 4 {
1544		return false
1545	}
1546	l := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1547	if l != len(data)-4 {
1548		return false
1549	}
1550	m.ciphertext = data[4:]
1551	return true
1552}
1553
1554type finishedMsg struct {
1555	raw        []byte
1556	verifyData []byte
1557}
1558
1559func (m *finishedMsg) marshal() []byte {
1560	if m.raw != nil {
1561		return m.raw
1562	}
1563
1564	var b cryptobyte.Builder
1565	b.AddUint8(typeFinished)
1566	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1567		b.AddBytes(m.verifyData)
1568	})
1569
1570	m.raw = b.BytesOrPanic()
1571	return m.raw
1572}
1573
1574func (m *finishedMsg) unmarshal(data []byte) bool {
1575	m.raw = data
1576	s := cryptobyte.String(data)
1577	return s.Skip(1) &&
1578		readUint24LengthPrefixed(&s, &m.verifyData) &&
1579		s.Empty()
1580}
1581
1582type nextProtoMsg struct {
1583	raw   []byte
1584	proto string
1585}
1586
1587func (m *nextProtoMsg) marshal() []byte {
1588	if m.raw != nil {
1589		return m.raw
1590	}
1591	l := len(m.proto)
1592	if l > 255 {
1593		l = 255
1594	}
1595
1596	padding := 32 - (l+2)%32
1597	length := l + padding + 2
1598	x := make([]byte, length+4)
1599	x[0] = typeNextProtocol
1600	x[1] = uint8(length >> 16)
1601	x[2] = uint8(length >> 8)
1602	x[3] = uint8(length)
1603
1604	y := x[4:]
1605	y[0] = byte(l)
1606	copy(y[1:], []byte(m.proto[0:l]))
1607	y = y[1+l:]
1608	y[0] = byte(padding)
1609
1610	m.raw = x
1611
1612	return x
1613}
1614
1615func (m *nextProtoMsg) unmarshal(data []byte) bool {
1616	m.raw = data
1617
1618	if len(data) < 5 {
1619		return false
1620	}
1621	data = data[4:]
1622	protoLen := int(data[0])
1623	data = data[1:]
1624	if len(data) < protoLen {
1625		return false
1626	}
1627	m.proto = string(data[0:protoLen])
1628	data = data[protoLen:]
1629
1630	if len(data) < 1 {
1631		return false
1632	}
1633	paddingLen := int(data[0])
1634	data = data[1:]
1635	if len(data) != paddingLen {
1636		return false
1637	}
1638
1639	return true
1640}
1641
1642type certificateRequestMsg struct {
1643	raw []byte
1644	// hasSignatureAlgorithm indicates whether this message includes a list of
1645	// supported signature algorithms. This change was introduced with TLS 1.2.
1646	hasSignatureAlgorithm bool
1647
1648	certificateTypes             []byte
1649	supportedSignatureAlgorithms []SignatureScheme
1650	certificateAuthorities       [][]byte
1651}
1652
1653func (m *certificateRequestMsg) marshal() (x []byte) {
1654	if m.raw != nil {
1655		return m.raw
1656	}
1657
1658	// See RFC 4346, Section 7.4.4.
1659	length := 1 + len(m.certificateTypes) + 2
1660	casLength := 0
1661	for _, ca := range m.certificateAuthorities {
1662		casLength += 2 + len(ca)
1663	}
1664	length += casLength
1665
1666	if m.hasSignatureAlgorithm {
1667		length += 2 + 2*len(m.supportedSignatureAlgorithms)
1668	}
1669
1670	x = make([]byte, 4+length)
1671	x[0] = typeCertificateRequest
1672	x[1] = uint8(length >> 16)
1673	x[2] = uint8(length >> 8)
1674	x[3] = uint8(length)
1675
1676	x[4] = uint8(len(m.certificateTypes))
1677
1678	copy(x[5:], m.certificateTypes)
1679	y := x[5+len(m.certificateTypes):]
1680
1681	if m.hasSignatureAlgorithm {
1682		n := len(m.supportedSignatureAlgorithms) * 2
1683		y[0] = uint8(n >> 8)
1684		y[1] = uint8(n)
1685		y = y[2:]
1686		for _, sigAlgo := range m.supportedSignatureAlgorithms {
1687			y[0] = uint8(sigAlgo >> 8)
1688			y[1] = uint8(sigAlgo)
1689			y = y[2:]
1690		}
1691	}
1692
1693	y[0] = uint8(casLength >> 8)
1694	y[1] = uint8(casLength)
1695	y = y[2:]
1696	for _, ca := range m.certificateAuthorities {
1697		y[0] = uint8(len(ca) >> 8)
1698		y[1] = uint8(len(ca))
1699		y = y[2:]
1700		copy(y, ca)
1701		y = y[len(ca):]
1702	}
1703
1704	m.raw = x
1705	return
1706}
1707
// unmarshal parses a pre-TLS 1.3 CertificateRequest message into m.
//
// m.hasSignatureAlgorithm is read, not written. The caller must set it
// beforehand to say whether the supported_signature_algorithms list is
// present. data is stored in m.raw even on failure. Fields parsed before a
// failure point (certificateTypes, supportedSignatureAlgorithms) stay
// populated when a later check fails. It reports whether the whole message
// was well-formed with no trailing bytes.
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
	m.raw = data

	// Minimum: 4-byte handshake header plus the certificate_types count byte.
	if len(data) < 5 {
		return false
	}

	// The uint24 handshake length must cover exactly the remaining bytes.
	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
	if uint32(len(data))-4 != length {
		return false
	}

	// certificate_types: a count byte, then that many type bytes. At least one
	// type is required, and at least one byte must follow the list.
	numCertTypes := int(data[4])
	data = data[5:]
	if numCertTypes == 0 || len(data) <= numCertTypes {
		return false
	}

	// Copied into a fresh slice, so m.certificateTypes does not alias data.
	// (The copy cannot come up short after the length check above.)
	m.certificateTypes = make([]byte, numCertTypes)
	if copy(m.certificateTypes, data) != numCertTypes {
		return false
	}

	data = data[numCertTypes:]

	if m.hasSignatureAlgorithm {
		// A uint16 byte length, then 2-byte SignatureScheme values. An odd
		// length cannot hold a whole number of entries.
		if len(data) < 2 {
			return false
		}
		sigAndHashLen := uint16(data[0])<<8 | uint16(data[1])
		data = data[2:]
		if sigAndHashLen&1 != 0 {
			return false
		}
		if len(data) < int(sigAndHashLen) {
			return false
		}
		numSigAlgos := sigAndHashLen / 2
		m.supportedSignatureAlgorithms = make([]SignatureScheme, numSigAlgos)
		for i := range m.supportedSignatureAlgorithms {
			m.supportedSignatureAlgorithms[i] = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
			data = data[2:]
		}
	}

	// certificate_authorities: a uint16 total length, then a sequence of
	// uint16-prefixed names.
	if len(data) < 2 {
		return false
	}
	casLength := uint16(data[0])<<8 | uint16(data[1])
	data = data[2:]
	if len(data) < int(casLength) {
		return false
	}
	// The list is copied into a fresh buffer, so the stored authority slices
	// alias cas rather than data.
	cas := make([]byte, casLength)
	copy(cas, data)
	data = data[casLength:]

	m.certificateAuthorities = nil
	for len(cas) > 0 {
		if len(cas) < 2 {
			return false
		}
		caLen := uint16(cas[0])<<8 | uint16(cas[1])
		cas = cas[2:]

		if len(cas) < int(caLen) {
			return false
		}

		m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen])
		cas = cas[caLen:]
	}

	// Reject trailing bytes after the authorities list.
	return len(data) == 0
}
1783
1784type certificateVerifyMsg struct {
1785	raw                   []byte
1786	hasSignatureAlgorithm bool // format change introduced in TLS 1.2
1787	signatureAlgorithm    SignatureScheme
1788	signature             []byte
1789}
1790
1791func (m *certificateVerifyMsg) marshal() (x []byte) {
1792	if m.raw != nil {
1793		return m.raw
1794	}
1795
1796	var b cryptobyte.Builder
1797	b.AddUint8(typeCertificateVerify)
1798	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
1799		if m.hasSignatureAlgorithm {
1800			b.AddUint16(uint16(m.signatureAlgorithm))
1801		}
1802		b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
1803			b.AddBytes(m.signature)
1804		})
1805	})
1806
1807	m.raw = b.BytesOrPanic()
1808	return m.raw
1809}
1810
1811func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
1812	m.raw = data
1813	s := cryptobyte.String(data)
1814
1815	if !s.Skip(4) { // message type and uint24 length field
1816		return false
1817	}
1818	if m.hasSignatureAlgorithm {
1819		if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
1820			return false
1821		}
1822	}
1823	return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
1824}
1825
1826type newSessionTicketMsg struct {
1827	raw    []byte
1828	ticket []byte
1829}
1830
1831func (m *newSessionTicketMsg) marshal() (x []byte) {
1832	if m.raw != nil {
1833		return m.raw
1834	}
1835
1836	// See RFC 5077, Section 3.3.
1837	ticketLen := len(m.ticket)
1838	length := 2 + 4 + ticketLen
1839	x = make([]byte, 4+length)
1840	x[0] = typeNewSessionTicket
1841	x[1] = uint8(length >> 16)
1842	x[2] = uint8(length >> 8)
1843	x[3] = uint8(length)
1844	x[8] = uint8(ticketLen >> 8)
1845	x[9] = uint8(ticketLen)
1846	copy(x[10:], m.ticket)
1847
1848	m.raw = x
1849
1850	return
1851}
1852
1853func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
1854	m.raw = data
1855
1856	if len(data) < 10 {
1857		return false
1858	}
1859
1860	length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
1861	if uint32(len(data))-4 != length {
1862		return false
1863	}
1864
1865	ticketLen := int(data[8])<<8 + int(data[9])
1866	if len(data)-10 != ticketLen {
1867		return false
1868	}
1869
1870	m.ticket = data[10:]
1871
1872	return true
1873}
1874
1875type helloRequestMsg struct {
1876}
1877
1878func (*helloRequestMsg) marshal() []byte {
1879	return []byte{typeHelloRequest, 0, 0, 0}
1880}
1881
1882func (*helloRequestMsg) unmarshal(data []byte) bool {
1883	return len(data) == 4
1884}
1885