1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package qtls
6
7import (
8	"bytes"
9	"crypto/aes"
10	"crypto/cipher"
11	"crypto/hmac"
12	"crypto/sha256"
13	"crypto/subtle"
14	"errors"
15	"io"
16	"time"
17
18	"golang.org/x/crypto/cryptobyte"
19)
20
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection. Unlike sessionStateTLS13, the
// protocol version is carried in a field (vers) rather than being fixed.
//
// Wire encoding (see marshal/unmarshal): vers, cipherSuite, createdAt, a
// uint16-length-prefixed master secret, then a uint24-length-prefixed list of
// uint24-length-prefixed certificates.
type sessionState struct {
	vers         uint16
	cipherSuite  uint16
	// NOTE(review): presumably Unix seconds, like sessionStateTLS13.createdAt;
	// the code that fills this field is not visible here — confirm.
	createdAt    uint64
	masterSecret []byte // opaque master_secret<1..2^16-1>;
	// struct { opaque certificate<1..2^24-1> } Certificate;
	certificates [][]byte // Certificate certificate_list<0..2^24-1>;

	// usedOldKey is true if the ticket from which this session came was
	// encrypted with an older key and thus should be refreshed. It is not
	// part of the serialized form: marshal does not write it and unmarshal
	// keeps its current value.
	usedOldKey bool
}
35
36func (m *sessionState) marshal() []byte {
37	var b cryptobyte.Builder
38	b.AddUint16(m.vers)
39	b.AddUint16(m.cipherSuite)
40	addUint64(&b, m.createdAt)
41	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
42		b.AddBytes(m.masterSecret)
43	})
44	b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
45		for _, cert := range m.certificates {
46			b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
47				b.AddBytes(cert)
48			})
49		}
50	})
51	return b.BytesOrPanic()
52}
53
54func (m *sessionState) unmarshal(data []byte) bool {
55	*m = sessionState{usedOldKey: m.usedOldKey}
56	s := cryptobyte.String(data)
57	if ok := s.ReadUint16(&m.vers) &&
58		s.ReadUint16(&m.cipherSuite) &&
59		readUint64(&s, &m.createdAt) &&
60		readUint16LengthPrefixed(&s, &m.masterSecret) &&
61		len(m.masterSecret) != 0; !ok {
62		return false
63	}
64	var certList cryptobyte.String
65	if !s.ReadUint24LengthPrefixed(&certList) {
66		return false
67	}
68	for !certList.Empty() {
69		var cert []byte
70		if !readUint24LengthPrefixed(&certList, &cert) {
71			return false
72		}
73		m.certificates = append(m.certificates, cert)
74	}
75	return s.Empty()
76}
77
// sessionStateTLS13 is the content of a TLS 1.3 session ticket. Its encoding
// starts with the protocol version and a revision number:
//   - revision 0 doesn't carry any of the information needed for 0-RTT
//     validation, and the nonce is always empty;
//   - revision 1 carries the max_early_data_size sent in the ticket;
//   - revision 2 carries the ALPN sent in the ticket.
//
// marshal always writes revision 2, and unmarshal rejects any other revision.
type sessionStateTLS13 struct {
	// uint16 version  = 0x0304;
	// uint8  revision = 2;
	cipherSuite      uint16
	// Creation time; getSessionTicketMsg stores config.time().Unix().
	createdAt        uint64
	resumptionSecret []byte      // opaque resumption_master_secret<1..2^8-1>;
	certificate      Certificate // CertificateEntry certificate_list<0..2^24-1>;
	// uint32 max_early_data_size;
	maxEarlyData     uint32
	// opaque alpn<0..2^8-1>; encoded with a uint8 length prefix.
	alpn             string

	// opaque app_data<0..2^16-1>; caller-supplied data stored in the ticket,
	// encoded with a uint16 length prefix.
	appData []byte
}
95
96func (m *sessionStateTLS13) marshal() []byte {
97	var b cryptobyte.Builder
98	b.AddUint16(VersionTLS13)
99	b.AddUint8(2) // revision
100	b.AddUint16(m.cipherSuite)
101	addUint64(&b, m.createdAt)
102	b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
103		b.AddBytes(m.resumptionSecret)
104	})
105	marshalCertificate(&b, m.certificate)
106	b.AddUint32(m.maxEarlyData)
107	b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
108		b.AddBytes([]byte(m.alpn))
109	})
110	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
111		b.AddBytes(m.appData)
112	})
113	return b.BytesOrPanic()
114}
115
116func (m *sessionStateTLS13) unmarshal(data []byte) bool {
117	*m = sessionStateTLS13{}
118	s := cryptobyte.String(data)
119	var version uint16
120	var revision uint8
121	var alpn []byte
122	ret := s.ReadUint16(&version) &&
123		version == VersionTLS13 &&
124		s.ReadUint8(&revision) &&
125		revision == 2 &&
126		s.ReadUint16(&m.cipherSuite) &&
127		readUint64(&s, &m.createdAt) &&
128		readUint8LengthPrefixed(&s, &m.resumptionSecret) &&
129		len(m.resumptionSecret) != 0 &&
130		unmarshalCertificate(&s, &m.certificate) &&
131		s.ReadUint32(&m.maxEarlyData) &&
132		readUint8LengthPrefixed(&s, &alpn) &&
133		readUint16LengthPrefixed(&s, &m.appData) &&
134		s.Empty()
135	m.alpn = string(alpn)
136	return ret
137}
138
139func (c *Conn) encryptTicket(state []byte) ([]byte, error) {
140	if len(c.ticketKeys) == 0 {
141		return nil, errors.New("tls: internal error: session ticket keys unavailable")
142	}
143
144	encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(state)+sha256.Size)
145	keyName := encrypted[:ticketKeyNameLen]
146	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
147	macBytes := encrypted[len(encrypted)-sha256.Size:]
148
149	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
150		return nil, err
151	}
152	key := c.ticketKeys[0]
153	copy(keyName, key.keyName[:])
154	block, err := aes.NewCipher(key.aesKey[:])
155	if err != nil {
156		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
157	}
158	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], state)
159
160	mac := hmac.New(sha256.New, key.hmacKey[:])
161	mac.Write(encrypted[:len(encrypted)-sha256.Size])
162	mac.Sum(macBytes[:0])
163
164	return encrypted, nil
165}
166
167func (c *Conn) decryptTicket(encrypted []byte) (plaintext []byte, usedOldKey bool) {
168	if len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
169		return nil, false
170	}
171
172	keyName := encrypted[:ticketKeyNameLen]
173	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
174	macBytes := encrypted[len(encrypted)-sha256.Size:]
175	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
176
177	keyIndex := -1
178	for i, candidateKey := range c.ticketKeys {
179		if bytes.Equal(keyName, candidateKey.keyName[:]) {
180			keyIndex = i
181			break
182		}
183	}
184	if keyIndex == -1 {
185		return nil, false
186	}
187	key := &c.ticketKeys[keyIndex]
188
189	mac := hmac.New(sha256.New, key.hmacKey[:])
190	mac.Write(encrypted[:len(encrypted)-sha256.Size])
191	expected := mac.Sum(nil)
192
193	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
194		return nil, false
195	}
196
197	block, err := aes.NewCipher(key.aesKey[:])
198	if err != nil {
199		return nil, false
200	}
201	plaintext = make([]byte, len(ciphertext))
202	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
203
204	return plaintext, keyIndex > 0
205}
206
207func (c *Conn) getSessionTicketMsg(appData []byte) (*newSessionTicketMsgTLS13, error) {
208	m := new(newSessionTicketMsgTLS13)
209
210	var certsFromClient [][]byte
211	for _, cert := range c.peerCertificates {
212		certsFromClient = append(certsFromClient, cert.Raw)
213	}
214	state := sessionStateTLS13{
215		cipherSuite:      c.cipherSuite,
216		createdAt:        uint64(c.config.time().Unix()),
217		resumptionSecret: c.resumptionSecret,
218		certificate: Certificate{
219			Certificate:                 certsFromClient,
220			OCSPStaple:                  c.ocspResponse,
221			SignedCertificateTimestamps: c.scts,
222		},
223		appData: appData,
224		alpn:    c.clientProtocol,
225	}
226	if c.extraConfig != nil {
227		state.maxEarlyData = c.extraConfig.MaxEarlyData
228	}
229	var err error
230	m.label, err = c.encryptTicket(state.marshal())
231	if err != nil {
232		return nil, err
233	}
234	m.lifetime = uint32(maxSessionTicketLifetime / time.Second)
235	if c.extraConfig != nil {
236		m.maxEarlyData = c.extraConfig.MaxEarlyData
237	}
238	return m, nil
239}
240
241// GetSessionTicket generates a new session ticket.
242// It should only be called after the handshake completes.
243// It can only be used for servers, and only if the alternative record layer is set.
244// The ticket may be nil if config.SessionTicketsDisabled is set,
245// or if the client isn't able to receive session tickets.
246func (c *Conn) GetSessionTicket(appData []byte) ([]byte, error) {
247	if c.isClient || !c.handshakeComplete() || c.extraConfig == nil || c.extraConfig.AlternativeRecordLayer == nil {
248		return nil, errors.New("GetSessionTicket is only valid for servers after completion of the handshake, and if an alternative record layer is set.")
249	}
250	if c.config.SessionTicketsDisabled {
251		return nil, nil
252	}
253
254	m, err := c.getSessionTicketMsg(appData)
255	if err != nil {
256		return nil, err
257	}
258	return m.marshal(), nil
259}
260