1// Copyright 2017 Google Inc. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"errors"
9	"io"
10)
11
12type TLSExtension interface {
13	writeToUConn(*UConn) error
14
15	Len() int // includes header
16
17	// Read reads up to len(p) bytes into p.
18	// It returns the number of bytes read (0 <= n <= len(p)) and any error encountered.
19	Read(p []byte) (n int, err error) // implements io.Reader
20}
21
22type NPNExtension struct {
23	NextProtos []string
24}
25
26func (e *NPNExtension) writeToUConn(uc *UConn) error {
27	uc.config.NextProtos = e.NextProtos
28	uc.HandshakeState.Hello.NextProtoNeg = true
29	return nil
30}
31
32func (e *NPNExtension) Len() int {
33	return 4
34}
35
36func (e *NPNExtension) Read(b []byte) (int, error) {
37	if len(b) < e.Len() {
38		return 0, io.ErrShortBuffer
39	}
40	b[0] = byte(extensionNextProtoNeg >> 8)
41	b[1] = byte(extensionNextProtoNeg & 0xff)
42	// The length is always 0
43	return e.Len(), io.EOF
44}
45
46type SNIExtension struct {
47	ServerName string // not an array because go crypto/tls doesn't support multiple SNIs
48}
49
50func (e *SNIExtension) writeToUConn(uc *UConn) error {
51	uc.config.ServerName = e.ServerName
52	uc.HandshakeState.Hello.ServerName = e.ServerName
53	return nil
54}
55
56func (e *SNIExtension) Len() int {
57	return 4 + 2 + 1 + 2 + len(e.ServerName)
58}
59
60func (e *SNIExtension) Read(b []byte) (int, error) {
61	if len(b) < e.Len() {
62		return 0, io.ErrShortBuffer
63	}
64	// RFC 3546, section 3.1
65	b[0] = byte(extensionServerName >> 8)
66	b[1] = byte(extensionServerName)
67	b[2] = byte((len(e.ServerName) + 5) >> 8)
68	b[3] = byte((len(e.ServerName) + 5))
69	b[4] = byte((len(e.ServerName) + 3) >> 8)
70	b[5] = byte(len(e.ServerName) + 3)
71	// b[6] Server Name Type: host_name (0)
72	b[7] = byte(len(e.ServerName) >> 8)
73	b[8] = byte(len(e.ServerName))
74	copy(b[9:], []byte(e.ServerName))
75	return e.Len(), io.EOF
76}
77
78type StatusRequestExtension struct {
79}
80
81func (e *StatusRequestExtension) writeToUConn(uc *UConn) error {
82	uc.HandshakeState.Hello.OcspStapling = true
83	return nil
84}
85
86func (e *StatusRequestExtension) Len() int {
87	return 9
88}
89
90func (e *StatusRequestExtension) Read(b []byte) (int, error) {
91	if len(b) < e.Len() {
92		return 0, io.ErrShortBuffer
93	}
94	// RFC 4366, section 3.6
95	b[0] = byte(extensionStatusRequest >> 8)
96	b[1] = byte(extensionStatusRequest)
97	b[2] = 0
98	b[3] = 5
99	b[4] = 1 // OCSP type
100	// Two zero valued uint16s for the two lengths.
101	return e.Len(), io.EOF
102}
103
104type SupportedCurvesExtension struct {
105	Curves []CurveID
106}
107
108func (e *SupportedCurvesExtension) writeToUConn(uc *UConn) error {
109	uc.config.CurvePreferences = e.Curves
110	uc.HandshakeState.Hello.SupportedCurves = e.Curves
111	return nil
112}
113
114func (e *SupportedCurvesExtension) Len() int {
115	return 6 + 2*len(e.Curves)
116}
117
118func (e *SupportedCurvesExtension) Read(b []byte) (int, error) {
119	if len(b) < e.Len() {
120		return 0, io.ErrShortBuffer
121	}
122	// http://tools.ietf.org/html/rfc4492#section-5.5.1
123	b[0] = byte(extensionSupportedCurves >> 8)
124	b[1] = byte(extensionSupportedCurves)
125	b[2] = byte((2 + 2*len(e.Curves)) >> 8)
126	b[3] = byte((2 + 2*len(e.Curves)))
127	b[4] = byte((2 * len(e.Curves)) >> 8)
128	b[5] = byte((2 * len(e.Curves)))
129	for i, curve := range e.Curves {
130		b[6+2*i] = byte(curve >> 8)
131		b[7+2*i] = byte(curve)
132	}
133	return e.Len(), io.EOF
134}
135
136type SupportedPointsExtension struct {
137	SupportedPoints []uint8
138}
139
140func (e *SupportedPointsExtension) writeToUConn(uc *UConn) error {
141	uc.HandshakeState.Hello.SupportedPoints = e.SupportedPoints
142	return nil
143}
144
145func (e *SupportedPointsExtension) Len() int {
146	return 5 + len(e.SupportedPoints)
147}
148
149func (e *SupportedPointsExtension) Read(b []byte) (int, error) {
150	if len(b) < e.Len() {
151		return 0, io.ErrShortBuffer
152	}
153	// http://tools.ietf.org/html/rfc4492#section-5.5.2
154	b[0] = byte(extensionSupportedPoints >> 8)
155	b[1] = byte(extensionSupportedPoints)
156	b[2] = byte((1 + len(e.SupportedPoints)) >> 8)
157	b[3] = byte((1 + len(e.SupportedPoints)))
158	b[4] = byte((len(e.SupportedPoints)))
159	for i, pointFormat := range e.SupportedPoints {
160		b[5+i] = pointFormat
161	}
162	return e.Len(), io.EOF
163}
164
165type SignatureAlgorithmsExtension struct {
166	SupportedSignatureAlgorithms []SignatureScheme
167}
168
169func (e *SignatureAlgorithmsExtension) writeToUConn(uc *UConn) error {
170	uc.HandshakeState.Hello.SupportedSignatureAlgorithms = e.SupportedSignatureAlgorithms
171	return nil
172}
173
174func (e *SignatureAlgorithmsExtension) Len() int {
175	return 6 + 2*len(e.SupportedSignatureAlgorithms)
176}
177
178func (e *SignatureAlgorithmsExtension) Read(b []byte) (int, error) {
179	if len(b) < e.Len() {
180		return 0, io.ErrShortBuffer
181	}
182	// https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1
183	b[0] = byte(extensionSignatureAlgorithms >> 8)
184	b[1] = byte(extensionSignatureAlgorithms)
185	b[2] = byte((2 + 2*len(e.SupportedSignatureAlgorithms)) >> 8)
186	b[3] = byte((2 + 2*len(e.SupportedSignatureAlgorithms)))
187	b[4] = byte((2 * len(e.SupportedSignatureAlgorithms)) >> 8)
188	b[5] = byte((2 * len(e.SupportedSignatureAlgorithms)))
189	for i, sigAndHash := range e.SupportedSignatureAlgorithms {
190		b[6+2*i] = byte(sigAndHash >> 8)
191		b[7+2*i] = byte(sigAndHash)
192	}
193	return e.Len(), io.EOF
194}
195
196type RenegotiationInfoExtension struct {
197	// Renegotiation field limits how many times client will perform renegotiation: no limit, once, or never.
198	// The extension still will be sent, even if Renegotiation is set to RenegotiateNever.
199	Renegotiation RenegotiationSupport
200}
201
202func (e *RenegotiationInfoExtension) writeToUConn(uc *UConn) error {
203	uc.config.Renegotiation = e.Renegotiation
204	switch e.Renegotiation {
205	case RenegotiateOnceAsClient:
206		fallthrough
207	case RenegotiateFreelyAsClient:
208		uc.HandshakeState.Hello.SecureRenegotiationSupported = true
209	case RenegotiateNever:
210	default:
211	}
212	return nil
213}
214
215func (e *RenegotiationInfoExtension) Len() int {
216	return 5
217}
218
219func (e *RenegotiationInfoExtension) Read(b []byte) (int, error) {
220	if len(b) < e.Len() {
221		return 0, io.ErrShortBuffer
222	}
223
224	var extInnerBody []byte // inner body is empty
225	innerBodyLen := len(extInnerBody)
226	extBodyLen := innerBodyLen + 1
227
228	b[0] = byte(extensionRenegotiationInfo >> 8)
229	b[1] = byte(extensionRenegotiationInfo & 0xff)
230	b[2] = byte(extBodyLen >> 8)
231	b[3] = byte(extBodyLen)
232	b[4] = byte(innerBodyLen)
233	copy(b[5:], extInnerBody)
234
235	return e.Len(), io.EOF
236}
237
238type ALPNExtension struct {
239	AlpnProtocols []string
240}
241
242func (e *ALPNExtension) writeToUConn(uc *UConn) error {
243	uc.config.NextProtos = e.AlpnProtocols
244	uc.HandshakeState.Hello.AlpnProtocols = e.AlpnProtocols
245	return nil
246}
247
248func (e *ALPNExtension) Len() int {
249	bLen := 2 + 2 + 2
250	for _, s := range e.AlpnProtocols {
251		bLen += 1 + len(s)
252	}
253	return bLen
254}
255
256func (e *ALPNExtension) Read(b []byte) (int, error) {
257	if len(b) < e.Len() {
258		return 0, io.ErrShortBuffer
259	}
260
261	b[0] = byte(extensionALPN >> 8)
262	b[1] = byte(extensionALPN & 0xff)
263	lengths := b[2:]
264	b = b[6:]
265
266	stringsLength := 0
267	for _, s := range e.AlpnProtocols {
268		l := len(s)
269		b[0] = byte(l)
270		copy(b[1:], s)
271		b = b[1+l:]
272		stringsLength += 1 + l
273	}
274
275	lengths[2] = byte(stringsLength >> 8)
276	lengths[3] = byte(stringsLength)
277	stringsLength += 2
278	lengths[0] = byte(stringsLength >> 8)
279	lengths[1] = byte(stringsLength)
280
281	return e.Len(), io.EOF
282}
283
284type SCTExtension struct {
285}
286
287func (e *SCTExtension) writeToUConn(uc *UConn) error {
288	uc.HandshakeState.Hello.Scts = true
289	return nil
290}
291
292func (e *SCTExtension) Len() int {
293	return 4
294}
295
296func (e *SCTExtension) Read(b []byte) (int, error) {
297	if len(b) < e.Len() {
298		return 0, io.ErrShortBuffer
299	}
300	// https://tools.ietf.org/html/rfc6962#section-3.3.1
301	b[0] = byte(extensionSCT >> 8)
302	b[1] = byte(extensionSCT)
303	// zero uint16 for the zero-length extension_data
304	return e.Len(), io.EOF
305}
306
307type SessionTicketExtension struct {
308	Session *ClientSessionState
309}
310
311func (e *SessionTicketExtension) writeToUConn(uc *UConn) error {
312	if e.Session != nil {
313		uc.HandshakeState.Session = e.Session
314		uc.HandshakeState.Hello.SessionTicket = e.Session.sessionTicket
315	}
316	return nil
317}
318
319func (e *SessionTicketExtension) Len() int {
320	if e.Session != nil {
321		return 4 + len(e.Session.sessionTicket)
322	}
323	return 4
324}
325
326func (e *SessionTicketExtension) Read(b []byte) (int, error) {
327	if len(b) < e.Len() {
328		return 0, io.ErrShortBuffer
329	}
330
331	extBodyLen := e.Len() - 4
332
333	b[0] = byte(extensionSessionTicket >> 8)
334	b[1] = byte(extensionSessionTicket)
335	b[2] = byte(extBodyLen >> 8)
336	b[3] = byte(extBodyLen)
337	if extBodyLen > 0 {
338		copy(b[4:], e.Session.sessionTicket)
339	}
340	return e.Len(), io.EOF
341}
342
343// GenericExtension allows to include in ClientHello arbitrary unsupported extensions.
344type GenericExtension struct {
345	Id   uint16
346	Data []byte
347}
348
349func (e *GenericExtension) writeToUConn(uc *UConn) error {
350	return nil
351}
352
353func (e *GenericExtension) Len() int {
354	return 4 + len(e.Data)
355}
356
357func (e *GenericExtension) Read(b []byte) (int, error) {
358	if len(b) < e.Len() {
359		return 0, io.ErrShortBuffer
360	}
361
362	b[0] = byte(e.Id >> 8)
363	b[1] = byte(e.Id)
364	b[2] = byte(len(e.Data) >> 8)
365	b[3] = byte(len(e.Data))
366	if len(e.Data) > 0 {
367		copy(b[4:], e.Data)
368	}
369	return e.Len(), io.EOF
370}
371
372type UtlsExtendedMasterSecretExtension struct {
373}
374
375// TODO: update when this extension is implemented in crypto/tls
376// but we probably won't have to enable it in Config
377func (e *UtlsExtendedMasterSecretExtension) writeToUConn(uc *UConn) error {
378	uc.HandshakeState.Hello.Ems = true
379	return nil
380}
381
382func (e *UtlsExtendedMasterSecretExtension) Len() int {
383	return 4
384}
385
386func (e *UtlsExtendedMasterSecretExtension) Read(b []byte) (int, error) {
387	if len(b) < e.Len() {
388		return 0, io.ErrShortBuffer
389	}
390	// https://tools.ietf.org/html/rfc7627
391	b[0] = byte(utlsExtensionExtendedMasterSecret >> 8)
392	b[1] = byte(utlsExtensionExtendedMasterSecret)
393	// The length is 0
394	return e.Len(), io.EOF
395}
396
397var extendedMasterSecretLabel = []byte("extended master secret")
398
399// extendedMasterFromPreMasterSecret generates the master secret from the pre-master
400// secret and session hash. See https://tools.ietf.org/html/rfc7627#section-4
401func extendedMasterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret []byte, fh finishedHash) []byte {
402	sessionHash := fh.Sum()
403	masterSecret := make([]byte, masterSecretLength)
404	prfForVersion(version, suite)(masterSecret, preMasterSecret, extendedMasterSecretLabel, sessionHash)
405	return masterSecret
406}
407
// GREASE has to be handled very carefully when mimicking other clients; if
// possible, do not include it. These indices mirror BoringSSL's GREASE seed
// slots:
// https://github.com/google/boringssl/blob/1c68fa2350936ca5897a66b430ebaf333a0e43f5/ssl/internal.h
//
// NOTE(review): ssl_grease_last_index equals ssl_grease_ticket_extension, so an
// array of length ssl_grease_last_index (as GetBoringGREASEValue takes) has no
// slot for ssl_grease_ticket_extension.
const (
	ssl_grease_cipher           = 0
	ssl_grease_group            = 1
	ssl_grease_extension1       = 2
	ssl_grease_extension2       = 3
	ssl_grease_version          = 4
	ssl_grease_ticket_extension = 5
	ssl_grease_last_index       = ssl_grease_ticket_extension
)
419
420// it is responsibility of user not to generate multiple grease extensions with same value
421type UtlsGREASEExtension struct {
422	Value uint16
423	Body  []byte // in Chrome first grease has empty body, second grease has a single zero byte
424}
425
426func (e *UtlsGREASEExtension) writeToUConn(uc *UConn) error {
427	return nil
428}
429
430// will panic if ssl_grease_last_index[index] is out of bounds.
431func GetBoringGREASEValue(greaseSeed [ssl_grease_last_index]uint16, index int) uint16 {
432	// GREASE value is back from deterministic to random.
433	// https://github.com/google/boringssl/blob/a365138ac60f38b64bfc608b493e0f879845cb88/ssl/handshake_client.c#L530
434	ret := uint16(greaseSeed[index])
435	/* This generates a random value of the form 0xωaωa, for all 0 ≤ ω < 16. */
436	ret = (ret & 0xf0) | 0x0a
437	ret |= ret << 8
438	return ret
439}
440
441func (e *UtlsGREASEExtension) Len() int {
442	return 4 + len(e.Body)
443}
444
445func (e *UtlsGREASEExtension) Read(b []byte) (int, error) {
446	if len(b) < e.Len() {
447		return 0, io.ErrShortBuffer
448	}
449
450	b[0] = byte(e.Value >> 8)
451	b[1] = byte(e.Value)
452	b[2] = byte(len(e.Body) >> 8)
453	b[3] = byte(len(e.Body))
454	if len(e.Body) > 0 {
455		copy(b[4:], e.Body)
456	}
457	return e.Len(), io.EOF
458}
459
460type UtlsPaddingExtension struct {
461	PaddingLen int
462	WillPad    bool // set to false to disable extension
463
464	// Functor for deciding on padding length based on unpadded ClientHello length.
465	// If willPad is false, then this extension should not be included.
466	GetPaddingLen func(clientHelloUnpaddedLen int) (paddingLen int, willPad bool)
467}
468
469func (e *UtlsPaddingExtension) writeToUConn(uc *UConn) error {
470	return nil
471}
472
473func (e *UtlsPaddingExtension) Len() int {
474	if e.WillPad {
475		return 4 + e.PaddingLen
476	} else {
477		return 0
478	}
479}
480
481func (e *UtlsPaddingExtension) Update(clientHelloUnpaddedLen int) {
482	if e.GetPaddingLen != nil {
483		e.PaddingLen, e.WillPad = e.GetPaddingLen(clientHelloUnpaddedLen)
484	}
485}
486
487func (e *UtlsPaddingExtension) Read(b []byte) (int, error) {
488	if !e.WillPad {
489		return 0, io.EOF
490	}
491	if len(b) < e.Len() {
492		return 0, io.ErrShortBuffer
493	}
494	// https://tools.ietf.org/html/rfc7627
495	b[0] = byte(utlsExtensionPadding >> 8)
496	b[1] = byte(utlsExtensionPadding)
497	b[2] = byte(e.PaddingLen >> 8)
498	b[3] = byte(e.PaddingLen)
499	return e.Len(), io.EOF
500}
501
// BoringPaddingStyle reproduces BoringSSL's ClientHello padding rule:
// https://github.com/google/boringssl/blob/7d7554b6b3c79e707e25521e61e066ce2b996e4c/ssl/t1_lib.c#L2803
//
// Only ClientHellos whose unpadded length is strictly between 0xff and 0x200
// are padded, to bring the total up to 0x200 bytes. The 4-byte extension header
// counts towards that total, and at least one byte of padding data is always
// sent. It returns the padding body length and whether to pad at all.
func BoringPaddingStyle(unpaddedLen int) (int, bool) {
	if unpaddedLen <= 0xff || unpaddedLen >= 0x200 {
		return 0, false
	}
	gap := 0x200 - unpaddedLen
	if gap < 4+1 {
		return 1, true
	}
	return gap - 4, true
}
515
516/* TLS 1.3 */
517type KeyShareExtension struct {
518	KeyShares []KeyShare
519}
520
521func (e *KeyShareExtension) Len() int {
522	return 4 + 2 + e.keySharesLen()
523}
524
525func (e *KeyShareExtension) keySharesLen() int {
526	extLen := 0
527	for _, ks := range e.KeyShares {
528		extLen += 4 + len(ks.Data)
529	}
530	return extLen
531}
532
533func (e *KeyShareExtension) Read(b []byte) (int, error) {
534	if len(b) < e.Len() {
535		return 0, io.ErrShortBuffer
536	}
537
538	b[0] = byte(extensionKeyShare >> 8)
539	b[1] = byte(extensionKeyShare)
540	keySharesLen := e.keySharesLen()
541	b[2] = byte((keySharesLen + 2) >> 8)
542	b[3] = byte((keySharesLen + 2))
543	b[4] = byte((keySharesLen) >> 8)
544	b[5] = byte((keySharesLen))
545
546	i := 6
547	for _, ks := range e.KeyShares {
548		b[i] = byte(ks.Group >> 8)
549		b[i+1] = byte(ks.Group)
550		b[i+2] = byte(len(ks.Data) >> 8)
551		b[i+3] = byte(len(ks.Data))
552		copy(b[i+4:], ks.Data)
553		i += 4 + len(ks.Data)
554	}
555
556	return e.Len(), io.EOF
557}
558
559func (e *KeyShareExtension) writeToUConn(uc *UConn) error {
560	uc.HandshakeState.Hello.KeyShares = e.KeyShares
561	return nil
562}
563
564type PSKKeyExchangeModesExtension struct {
565	Modes []uint8
566}
567
568func (e *PSKKeyExchangeModesExtension) Len() int {
569	return 4 + 1 + len(e.Modes)
570}
571
572func (e *PSKKeyExchangeModesExtension) Read(b []byte) (int, error) {
573	if len(b) < e.Len() {
574		return 0, io.ErrShortBuffer
575	}
576
577	if len(e.Modes) > 255 {
578		return 0, errors.New("too many PSK Key Exchange modes")
579	}
580
581	b[0] = byte(extensionPSKModes >> 8)
582	b[1] = byte(extensionPSKModes)
583
584	modesLen := len(e.Modes)
585	b[2] = byte((modesLen + 1) >> 8)
586	b[3] = byte((modesLen + 1))
587	b[4] = byte(modesLen)
588
589	if len(e.Modes) > 0 {
590		copy(b[5:], e.Modes)
591	}
592
593	return e.Len(), io.EOF
594}
595
596func (e *PSKKeyExchangeModesExtension) writeToUConn(uc *UConn) error {
597	uc.HandshakeState.Hello.PskModes = e.Modes
598	return nil
599}
600
601type SupportedVersionsExtension struct {
602	Versions []uint16
603}
604
605func (e *SupportedVersionsExtension) writeToUConn(uc *UConn) error {
606	uc.HandshakeState.Hello.SupportedVersions = e.Versions
607	return nil
608}
609
610func (e *SupportedVersionsExtension) Len() int {
611	return 4 + 1 + (2 * len(e.Versions))
612}
613
614func (e *SupportedVersionsExtension) Read(b []byte) (int, error) {
615	if len(b) < e.Len() {
616		return 0, io.ErrShortBuffer
617	}
618	extLen := 2 * len(e.Versions)
619	if extLen > 255 {
620		return 0, errors.New("too many supported versions")
621	}
622
623	b[0] = byte(extensionSupportedVersions >> 8)
624	b[1] = byte(extensionSupportedVersions)
625	b[2] = byte((extLen + 1) >> 8)
626	b[3] = byte((extLen + 1))
627	b[4] = byte(extLen)
628
629	i := 5
630	for _, sv := range e.Versions {
631		b[i] = byte(sv >> 8)
632		b[i+1] = byte(sv)
633		i += 2
634	}
635	return e.Len(), io.EOF
636}
637
638// MUST NOT be part of initial ClientHello
639type CookieExtension struct {
640	Cookie []byte
641}
642
643func (e *CookieExtension) writeToUConn(uc *UConn) error {
644	return nil
645}
646
647func (e *CookieExtension) Len() int {
648	return 4 + len(e.Cookie)
649}
650
651func (e *CookieExtension) Read(b []byte) (int, error) {
652	if len(b) < e.Len() {
653		return 0, io.ErrShortBuffer
654	}
655
656	b[0] = byte(extensionCookie >> 8)
657	b[1] = byte(extensionCookie)
658	b[2] = byte(len(e.Cookie) >> 8)
659	b[3] = byte(len(e.Cookie))
660	if len(e.Cookie) > 0 {
661		copy(b[4:], e.Cookie)
662	}
663	return e.Len(), io.EOF
664}
665
666/*
667FAKE EXTENSIONS
668*/
669
670type FakeChannelIDExtension struct {
671}
672
673func (e *FakeChannelIDExtension) writeToUConn(uc *UConn) error {
674	return nil
675}
676
677func (e *FakeChannelIDExtension) Len() int {
678	return 4
679}
680
681func (e *FakeChannelIDExtension) Read(b []byte) (int, error) {
682	if len(b) < e.Len() {
683		return 0, io.ErrShortBuffer
684	}
685	// https://tools.ietf.org/html/draft-balfanz-tls-channelid-00
686	b[0] = byte(fakeExtensionChannelID >> 8)
687	b[1] = byte(fakeExtensionChannelID & 0xff)
688	// The length is 0
689	return e.Len(), io.EOF
690}
691
692type FakeRecordSizeLimitExtension struct {
693	Limit uint16
694}
695
696func (e *FakeRecordSizeLimitExtension) writeToUConn(uc *UConn) error {
697	return nil
698}
699
700func (e *FakeRecordSizeLimitExtension) Len() int {
701	return 6
702}
703
704func (e *FakeRecordSizeLimitExtension) Read(b []byte) (int, error) {
705	if len(b) < e.Len() {
706		return 0, io.ErrShortBuffer
707	}
708	// https://tools.ietf.org/html/draft-balfanz-tls-channelid-00
709	b[0] = byte(fakeRecordSizeLimit >> 8)
710	b[1] = byte(fakeRecordSizeLimit & 0xff)
711
712	b[2] = byte(0)
713	b[3] = byte(2)
714
715	b[4] = byte(e.Limit >> 8)
716	b[5] = byte(e.Limit & 0xff)
717	return e.Len(), io.EOF
718}
719