1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runner
6
7import (
8	"crypto/aes"
9	"crypto/cipher"
10	"crypto/hmac"
11	"crypto/sha256"
12	"crypto/subtle"
13	"errors"
14	"io"
15	"time"
16)
17
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection. The wire encoding is defined
// by marshal and unmarshal below.
type sessionState struct {
	// vers is the protocol version. When it is >= VersionTLS13, the ticket
	// timing fields (ticketCreationTime through ticketAgeAdd) are serialized.
	vers                     uint16
	cipherSuite              uint16
	// masterSecret and handshakeHash are each encoded with a 16-bit length prefix.
	masterSecret             []byte
	handshakeHash            []byte
	// certificates is encoded as a 16-bit count followed by each entry with a
	// 32-bit length prefix.
	certificates             [][]byte
	// extendedMasterSecret is encoded as a single byte that must be 0 or 1.
	extendedMasterSecret     bool
	// earlyALPN is encoded with a 16-bit length prefix. Presumably this is the
	// ALPN protocol associated with early data (TODO: confirm against callers).
	earlyALPN                []byte
	// ticketCreationTime and ticketExpiration are stored as 64-bit UnixNano
	// values, and only for TLS 1.3 sessions.
	ticketCreationTime       time.Time
	ticketExpiration         time.Time
	// ticketFlags and ticketAgeAdd are 32-bit values, serialized only for TLS 1.3.
	ticketFlags              uint32
	ticketAgeAdd             uint32
	// hasApplicationSettings gates whether the two settings blobs below are
	// present in the encoding. Each blob has a 16-bit length prefix.
	hasApplicationSettings   bool
	localApplicationSettings []byte
	peerApplicationSettings  []byte
}
36
37func (s *sessionState) marshal() []byte {
38	msg := newByteBuilder()
39	msg.addU16(s.vers)
40	msg.addU16(s.cipherSuite)
41	masterSecret := msg.addU16LengthPrefixed()
42	masterSecret.addBytes(s.masterSecret)
43	handshakeHash := msg.addU16LengthPrefixed()
44	handshakeHash.addBytes(s.handshakeHash)
45	msg.addU16(uint16(len(s.certificates)))
46	for _, cert := range s.certificates {
47		certMsg := msg.addU32LengthPrefixed()
48		certMsg.addBytes(cert)
49	}
50
51	if s.extendedMasterSecret {
52		msg.addU8(1)
53	} else {
54		msg.addU8(0)
55	}
56
57	if s.vers >= VersionTLS13 {
58		msg.addU64(uint64(s.ticketCreationTime.UnixNano()))
59		msg.addU64(uint64(s.ticketExpiration.UnixNano()))
60		msg.addU32(s.ticketFlags)
61		msg.addU32(s.ticketAgeAdd)
62	}
63
64	earlyALPN := msg.addU16LengthPrefixed()
65	earlyALPN.addBytes(s.earlyALPN)
66
67	if s.hasApplicationSettings {
68		msg.addU8(1)
69		msg.addU16LengthPrefixed().addBytes(s.localApplicationSettings)
70		msg.addU16LengthPrefixed().addBytes(s.peerApplicationSettings)
71	} else {
72		msg.addU8(0)
73	}
74
75	return msg.finish()
76}
77
78func readBool(reader *byteReader, out *bool) bool {
79	var value uint8
80	if !reader.readU8(&value) {
81		return false
82	}
83	if value == 0 {
84		*out = false
85		return true
86	}
87	if value == 1 {
88		*out = true
89		return true
90	}
91	return false
92}
93
94func (s *sessionState) unmarshal(data []byte) bool {
95	reader := byteReader(data)
96	var numCerts uint16
97	if !reader.readU16(&s.vers) ||
98		!reader.readU16(&s.cipherSuite) ||
99		!reader.readU16LengthPrefixedBytes(&s.masterSecret) ||
100		!reader.readU16LengthPrefixedBytes(&s.handshakeHash) ||
101		!reader.readU16(&numCerts) {
102		return false
103	}
104
105	s.certificates = make([][]byte, int(numCerts))
106	for i := range s.certificates {
107		if !reader.readU32LengthPrefixedBytes(&s.certificates[i]) {
108			return false
109		}
110	}
111
112	if !readBool(&reader, &s.extendedMasterSecret) {
113		return false
114	}
115
116	if s.vers >= VersionTLS13 {
117		var ticketCreationTime, ticketExpiration uint64
118		if !reader.readU64(&ticketCreationTime) ||
119			!reader.readU64(&ticketExpiration) ||
120			!reader.readU32(&s.ticketFlags) ||
121			!reader.readU32(&s.ticketAgeAdd) {
122			return false
123		}
124		s.ticketCreationTime = time.Unix(0, int64(ticketCreationTime))
125		s.ticketExpiration = time.Unix(0, int64(ticketExpiration))
126	}
127
128	if !reader.readU16LengthPrefixedBytes(&s.earlyALPN) ||
129		!readBool(&reader, &s.hasApplicationSettings) {
130		return false
131	}
132
133	if s.hasApplicationSettings {
134		if !reader.readU16LengthPrefixedBytes(&s.localApplicationSettings) ||
135			!reader.readU16LengthPrefixedBytes(&s.peerApplicationSettings) {
136			return false
137		}
138	}
139
140	if len(reader) > 0 {
141		return false
142	}
143
144	return true
145}
146
147func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
148	serialized := state.marshal()
149	encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size)
150	iv := encrypted[:aes.BlockSize]
151	macBytes := encrypted[len(encrypted)-sha256.Size:]
152
153	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
154		return nil, err
155	}
156	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
157	if err != nil {
158		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
159	}
160	cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized)
161
162	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
163	mac.Write(encrypted[:len(encrypted)-sha256.Size])
164	mac.Sum(macBytes[:0])
165
166	return encrypted, nil
167}
168
169func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
170	if len(encrypted) < aes.BlockSize+sha256.Size {
171		return nil, false
172	}
173
174	iv := encrypted[:aes.BlockSize]
175	macBytes := encrypted[len(encrypted)-sha256.Size:]
176
177	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
178	mac.Write(encrypted[:len(encrypted)-sha256.Size])
179	expected := mac.Sum(nil)
180
181	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
182		return nil, false
183	}
184
185	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
186	if err != nil {
187		return nil, false
188	}
189	ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
190	plaintext := make([]byte, len(ciphertext))
191	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
192
193	state := new(sessionState)
194	ok := state.unmarshal(plaintext)
195	return state, ok
196}
197