1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"crypto"
10	"crypto/hmac"
11	"crypto/rsa"
12	"errors"
13	"hash"
14	"sync/atomic"
15	"time"
16)
17
// clientHandshakeStateTLS13 holds the client-side state of a TLS 1.3
// handshake, from the server's first reply (ServerHello or
// HelloRetryRequest) until the client Finished has been sent.
//
// Fields the caller sets before calling handshake:
//   - c: the connection being handshaken.
//   - hello: the ClientHello already sent; it is modified and resent if the
//     server answers with a HelloRetryRequest.
//   - serverHello: the server's reply; replaced by the second ServerHello
//     after a HelloRetryRequest.
//   - ecdheParams: private parameters for the single key share in hello;
//     replaced if a HelloRetryRequest selects a different group.
//   - session, earlySecret, binderKey: optional; only consulted when hello
//     offers a PSK for resumption.
//
// Fields filled in during the handshake:
//   - certReq: the server's CertificateRequest, or nil if none was sent.
//   - usingPSK: whether the server accepted the offered PSK.
//   - sentDummyCCS: whether the compatibility ChangeCipherSpec was written.
//   - suite: the negotiated cipher suite.
//   - transcript: running hash of the handshake messages.
//   - masterSecret: set by establishHandshakeKeys.
//   - trafficSecret: installed on the write side after the client Finished.
type clientHandshakeStateTLS13 struct {
	c           *Conn
	serverHello *serverHelloMsg
	hello       *clientHelloMsg
	ecdheParams ecdheParameters

	session     *ClientSessionState
	earlySecret []byte
	binderKey   []byte

	certReq       *certificateRequestMsgTLS13
	usingPSK      bool
	sentDummyCCS  bool
	suite         *cipherSuiteTLS13
	transcript    hash.Hash
	masterSecret  []byte
	trafficSecret []byte // client_application_traffic_secret_0
}
36
37// handshake requires hs.c, hs.hello, hs.serverHello, hs.ecdheParams, and,
38// optionally, hs.session, hs.earlySecret and hs.binderKey to be set.
39func (hs *clientHandshakeStateTLS13) handshake() error {
40	c := hs.c
41
42	// The server must not select TLS 1.3 in a renegotiation. See RFC 8446,
43	// sections 4.1.2 and 4.1.3.
44	if c.handshakes > 0 {
45		c.sendAlert(alertProtocolVersion)
46		return errors.New("tls: server selected TLS 1.3 in a renegotiation")
47	}
48
49	// Consistency check on the presence of a keyShare and its parameters.
50	if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 {
51		return c.sendAlert(alertInternalError)
52	}
53
54	if err := hs.checkServerHelloOrHRR(); err != nil {
55		return err
56	}
57
58	hs.transcript = hs.suite.hash.New()
59	hs.transcript.Write(hs.hello.marshal())
60
61	if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
62		if err := hs.sendDummyChangeCipherSpec(); err != nil {
63			return err
64		}
65		if err := hs.processHelloRetryRequest(); err != nil {
66			return err
67		}
68	}
69
70	hs.transcript.Write(hs.serverHello.marshal())
71
72	c.buffering = true
73	if err := hs.processServerHello(); err != nil {
74		return err
75	}
76	if err := hs.sendDummyChangeCipherSpec(); err != nil {
77		return err
78	}
79	if err := hs.establishHandshakeKeys(); err != nil {
80		return err
81	}
82	if err := hs.readServerParameters(); err != nil {
83		return err
84	}
85	if err := hs.readServerCertificate(); err != nil {
86		return err
87	}
88	if err := hs.readServerFinished(); err != nil {
89		return err
90	}
91	if err := hs.sendClientCertificate(); err != nil {
92		return err
93	}
94	if err := hs.sendClientFinished(); err != nil {
95		return err
96	}
97	if _, err := c.flush(); err != nil {
98		return err
99	}
100
101	atomic.StoreUint32(&c.handshakeStatus, 1)
102
103	return nil
104}
105
106// checkServerHelloOrHRR does validity checks that apply to both ServerHello and
107// HelloRetryRequest messages. It sets hs.suite.
108func (hs *clientHandshakeStateTLS13) checkServerHelloOrHRR() error {
109	c := hs.c
110
111	if hs.serverHello.supportedVersion == 0 {
112		c.sendAlert(alertMissingExtension)
113		return errors.New("tls: server selected TLS 1.3 using the legacy version field")
114	}
115
116	if hs.serverHello.supportedVersion != VersionTLS13 {
117		c.sendAlert(alertIllegalParameter)
118		return errors.New("tls: server selected an invalid version after a HelloRetryRequest")
119	}
120
121	if hs.serverHello.vers != VersionTLS12 {
122		c.sendAlert(alertIllegalParameter)
123		return errors.New("tls: server sent an incorrect legacy version")
124	}
125
126	if hs.serverHello.ocspStapling ||
127		hs.serverHello.ticketSupported ||
128		hs.serverHello.secureRenegotiationSupported ||
129		len(hs.serverHello.secureRenegotiation) != 0 ||
130		len(hs.serverHello.alpnProtocol) != 0 ||
131		len(hs.serverHello.scts) != 0 {
132		c.sendAlert(alertUnsupportedExtension)
133		return errors.New("tls: server sent a ServerHello extension forbidden in TLS 1.3")
134	}
135
136	if !bytes.Equal(hs.hello.sessionId, hs.serverHello.sessionId) {
137		c.sendAlert(alertIllegalParameter)
138		return errors.New("tls: server did not echo the legacy session ID")
139	}
140
141	if hs.serverHello.compressionMethod != compressionNone {
142		c.sendAlert(alertIllegalParameter)
143		return errors.New("tls: server selected unsupported compression format")
144	}
145
146	selectedSuite := mutualCipherSuiteTLS13(hs.hello.cipherSuites, hs.serverHello.cipherSuite)
147	if hs.suite != nil && selectedSuite != hs.suite {
148		c.sendAlert(alertIllegalParameter)
149		return errors.New("tls: server changed cipher suite after a HelloRetryRequest")
150	}
151	if selectedSuite == nil {
152		c.sendAlert(alertIllegalParameter)
153		return errors.New("tls: server chose an unconfigured cipher suite")
154	}
155	hs.suite = selectedSuite
156	c.cipherSuite = hs.suite.id
157
158	return nil
159}
160
161// sendDummyChangeCipherSpec sends a ChangeCipherSpec record for compatibility
162// with middleboxes that didn't implement TLS correctly. See RFC 8446, Appendix D.4.
163func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
164	if hs.sentDummyCCS {
165		return nil
166	}
167	hs.sentDummyCCS = true
168
169	_, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
170	return err
171}
172
173// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
174// resends hs.hello, and reads the new ServerHello into hs.serverHello.
175func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
176	c := hs.c
177
178	// The first ClientHello gets double-hashed into the transcript upon a
179	// HelloRetryRequest. See RFC 8446, Section 4.4.1.
180	chHash := hs.transcript.Sum(nil)
181	hs.transcript.Reset()
182	hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
183	hs.transcript.Write(chHash)
184	hs.transcript.Write(hs.serverHello.marshal())
185
186	if hs.serverHello.serverShare.group != 0 {
187		c.sendAlert(alertDecodeError)
188		return errors.New("tls: received malformed key_share extension")
189	}
190
191	curveID := hs.serverHello.selectedGroup
192	if curveID == 0 {
193		c.sendAlert(alertMissingExtension)
194		return errors.New("tls: received HelloRetryRequest without selected group")
195	}
196	curveOK := false
197	for _, id := range hs.hello.supportedCurves {
198		if id == curveID {
199			curveOK = true
200			break
201		}
202	}
203	if !curveOK {
204		c.sendAlert(alertIllegalParameter)
205		return errors.New("tls: server selected unsupported group")
206	}
207	if hs.ecdheParams.CurveID() == curveID {
208		c.sendAlert(alertIllegalParameter)
209		return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
210	}
211	if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok {
212		c.sendAlert(alertInternalError)
213		return errors.New("tls: CurvePreferences includes unsupported curve")
214	}
215	params, err := generateECDHEParameters(c.config.rand(), curveID)
216	if err != nil {
217		c.sendAlert(alertInternalError)
218		return err
219	}
220	hs.ecdheParams = params
221	hs.hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}}
222
223	hs.hello.cookie = hs.serverHello.cookie
224
225	hs.hello.raw = nil
226	if len(hs.hello.pskIdentities) > 0 {
227		pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
228		if pskSuite == nil {
229			return c.sendAlert(alertInternalError)
230		}
231		if pskSuite.hash == hs.suite.hash {
232			// Update binders and obfuscated_ticket_age.
233			ticketAge := uint32(c.config.time().Sub(hs.session.receivedAt) / time.Millisecond)
234			hs.hello.pskIdentities[0].obfuscatedTicketAge = ticketAge + hs.session.ageAdd
235
236			transcript := hs.suite.hash.New()
237			transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
238			transcript.Write(chHash)
239			transcript.Write(hs.serverHello.marshal())
240			transcript.Write(hs.hello.marshalWithoutBinders())
241			pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
242			hs.hello.updateBinders(pskBinders)
243		} else {
244			// Server selected a cipher suite incompatible with the PSK.
245			hs.hello.pskIdentities = nil
246			hs.hello.pskBinders = nil
247		}
248	}
249
250	hs.transcript.Write(hs.hello.marshal())
251	if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
252		return err
253	}
254
255	msg, err := c.readHandshake()
256	if err != nil {
257		return err
258	}
259
260	serverHello, ok := msg.(*serverHelloMsg)
261	if !ok {
262		c.sendAlert(alertUnexpectedMessage)
263		return unexpectedMessageError(serverHello, msg)
264	}
265	hs.serverHello = serverHello
266
267	if err := hs.checkServerHelloOrHRR(); err != nil {
268		return err
269	}
270
271	return nil
272}
273
274func (hs *clientHandshakeStateTLS13) processServerHello() error {
275	c := hs.c
276
277	if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
278		c.sendAlert(alertUnexpectedMessage)
279		return errors.New("tls: server sent two HelloRetryRequest messages")
280	}
281
282	if len(hs.serverHello.cookie) != 0 {
283		c.sendAlert(alertUnsupportedExtension)
284		return errors.New("tls: server sent a cookie in a normal ServerHello")
285	}
286
287	if hs.serverHello.selectedGroup != 0 {
288		c.sendAlert(alertDecodeError)
289		return errors.New("tls: malformed key_share extension")
290	}
291
292	if hs.serverHello.serverShare.group == 0 {
293		c.sendAlert(alertIllegalParameter)
294		return errors.New("tls: server did not send a key share")
295	}
296	if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() {
297		c.sendAlert(alertIllegalParameter)
298		return errors.New("tls: server selected unsupported group")
299	}
300
301	if !hs.serverHello.selectedIdentityPresent {
302		return nil
303	}
304
305	if int(hs.serverHello.selectedIdentity) >= len(hs.hello.pskIdentities) {
306		c.sendAlert(alertIllegalParameter)
307		return errors.New("tls: server selected an invalid PSK")
308	}
309
310	if len(hs.hello.pskIdentities) != 1 || hs.session == nil {
311		return c.sendAlert(alertInternalError)
312	}
313	pskSuite := cipherSuiteTLS13ByID(hs.session.cipherSuite)
314	if pskSuite == nil {
315		return c.sendAlert(alertInternalError)
316	}
317	if pskSuite.hash != hs.suite.hash {
318		c.sendAlert(alertIllegalParameter)
319		return errors.New("tls: server selected an invalid PSK and cipher suite pair")
320	}
321
322	hs.usingPSK = true
323	c.didResume = true
324	c.peerCertificates = hs.session.serverCertificates
325	c.verifiedChains = hs.session.verifiedChains
326	return nil
327}
328
329func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
330	c := hs.c
331
332	sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data)
333	if sharedKey == nil {
334		c.sendAlert(alertIllegalParameter)
335		return errors.New("tls: invalid server key share")
336	}
337
338	earlySecret := hs.earlySecret
339	if !hs.usingPSK {
340		earlySecret = hs.suite.extract(nil, nil)
341	}
342	handshakeSecret := hs.suite.extract(sharedKey,
343		hs.suite.deriveSecret(earlySecret, "derived", nil))
344
345	clientSecret := hs.suite.deriveSecret(handshakeSecret,
346		clientHandshakeTrafficLabel, hs.transcript)
347	c.out.setTrafficSecret(hs.suite, clientSecret)
348	serverSecret := hs.suite.deriveSecret(handshakeSecret,
349		serverHandshakeTrafficLabel, hs.transcript)
350	c.in.setTrafficSecret(hs.suite, serverSecret)
351
352	err := c.config.writeKeyLog(keyLogLabelClientHandshake, hs.hello.random, clientSecret)
353	if err != nil {
354		c.sendAlert(alertInternalError)
355		return err
356	}
357	err = c.config.writeKeyLog(keyLogLabelServerHandshake, hs.hello.random, serverSecret)
358	if err != nil {
359		c.sendAlert(alertInternalError)
360		return err
361	}
362
363	hs.masterSecret = hs.suite.extract(nil,
364		hs.suite.deriveSecret(handshakeSecret, "derived", nil))
365
366	return nil
367}
368
369func (hs *clientHandshakeStateTLS13) readServerParameters() error {
370	c := hs.c
371
372	msg, err := c.readHandshake()
373	if err != nil {
374		return err
375	}
376
377	encryptedExtensions, ok := msg.(*encryptedExtensionsMsg)
378	if !ok {
379		c.sendAlert(alertUnexpectedMessage)
380		return unexpectedMessageError(encryptedExtensions, msg)
381	}
382	hs.transcript.Write(encryptedExtensions.marshal())
383
384	if len(encryptedExtensions.alpnProtocol) != 0 && len(hs.hello.alpnProtocols) == 0 {
385		c.sendAlert(alertUnsupportedExtension)
386		return errors.New("tls: server advertised unrequested ALPN extension")
387	}
388	c.clientProtocol = encryptedExtensions.alpnProtocol
389
390	return nil
391}
392
393func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
394	c := hs.c
395
396	// Either a PSK or a certificate is always used, but not both.
397	// See RFC 8446, Section 4.1.1.
398	if hs.usingPSK {
399		return nil
400	}
401
402	msg, err := c.readHandshake()
403	if err != nil {
404		return err
405	}
406
407	certReq, ok := msg.(*certificateRequestMsgTLS13)
408	if ok {
409		hs.transcript.Write(certReq.marshal())
410
411		hs.certReq = certReq
412
413		msg, err = c.readHandshake()
414		if err != nil {
415			return err
416		}
417	}
418
419	certMsg, ok := msg.(*certificateMsgTLS13)
420	if !ok {
421		c.sendAlert(alertUnexpectedMessage)
422		return unexpectedMessageError(certMsg, msg)
423	}
424	if len(certMsg.certificate.Certificate) == 0 {
425		c.sendAlert(alertDecodeError)
426		return errors.New("tls: received empty certificates message")
427	}
428	hs.transcript.Write(certMsg.marshal())
429
430	c.scts = certMsg.certificate.SignedCertificateTimestamps
431	c.ocspResponse = certMsg.certificate.OCSPStaple
432
433	if err := c.verifyServerCertificate(certMsg.certificate.Certificate); err != nil {
434		return err
435	}
436
437	msg, err = c.readHandshake()
438	if err != nil {
439		return err
440	}
441
442	certVerify, ok := msg.(*certificateVerifyMsg)
443	if !ok {
444		c.sendAlert(alertUnexpectedMessage)
445		return unexpectedMessageError(certVerify, msg)
446	}
447
448	// See RFC 8446, Section 4.4.3.
449	if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms) {
450		c.sendAlert(alertIllegalParameter)
451		return errors.New("tls: certificate used with invalid signature algorithm")
452	}
453	sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
454	if err != nil {
455		return c.sendAlert(alertInternalError)
456	}
457	if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
458		c.sendAlert(alertIllegalParameter)
459		return errors.New("tls: certificate used with invalid signature algorithm")
460	}
461	signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
462	if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
463		sigHash, signed, certVerify.signature); err != nil {
464		c.sendAlert(alertDecryptError)
465		return errors.New("tls: invalid signature by the server certificate: " + err.Error())
466	}
467
468	hs.transcript.Write(certVerify.marshal())
469
470	return nil
471}
472
473func (hs *clientHandshakeStateTLS13) readServerFinished() error {
474	c := hs.c
475
476	msg, err := c.readHandshake()
477	if err != nil {
478		return err
479	}
480
481	finished, ok := msg.(*finishedMsg)
482	if !ok {
483		c.sendAlert(alertUnexpectedMessage)
484		return unexpectedMessageError(finished, msg)
485	}
486
487	expectedMAC := hs.suite.finishedHash(c.in.trafficSecret, hs.transcript)
488	if !hmac.Equal(expectedMAC, finished.verifyData) {
489		c.sendAlert(alertDecryptError)
490		return errors.New("tls: invalid server finished hash")
491	}
492
493	hs.transcript.Write(finished.marshal())
494
495	// Derive secrets that take context through the server Finished.
496
497	hs.trafficSecret = hs.suite.deriveSecret(hs.masterSecret,
498		clientApplicationTrafficLabel, hs.transcript)
499	serverSecret := hs.suite.deriveSecret(hs.masterSecret,
500		serverApplicationTrafficLabel, hs.transcript)
501	c.in.setTrafficSecret(hs.suite, serverSecret)
502
503	err = c.config.writeKeyLog(keyLogLabelClientTraffic, hs.hello.random, hs.trafficSecret)
504	if err != nil {
505		c.sendAlert(alertInternalError)
506		return err
507	}
508	err = c.config.writeKeyLog(keyLogLabelServerTraffic, hs.hello.random, serverSecret)
509	if err != nil {
510		c.sendAlert(alertInternalError)
511		return err
512	}
513
514	c.ekm = hs.suite.exportKeyingMaterial(hs.masterSecret, hs.transcript)
515
516	return nil
517}
518
519func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
520	c := hs.c
521
522	if hs.certReq == nil {
523		return nil
524	}
525
526	cert, err := c.getClientCertificate(&CertificateRequestInfo{
527		AcceptableCAs:    hs.certReq.certificateAuthorities,
528		SignatureSchemes: hs.certReq.supportedSignatureAlgorithms,
529		Version:          c.vers,
530	})
531	if err != nil {
532		return err
533	}
534
535	certMsg := new(certificateMsgTLS13)
536
537	certMsg.certificate = *cert
538	certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
539	certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
540
541	hs.transcript.Write(certMsg.marshal())
542	if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
543		return err
544	}
545
546	// If we sent an empty certificate message, skip the CertificateVerify.
547	if len(cert.Certificate) == 0 {
548		return nil
549	}
550
551	certVerifyMsg := new(certificateVerifyMsg)
552	certVerifyMsg.hasSignatureAlgorithm = true
553
554	certVerifyMsg.signatureAlgorithm, err = selectSignatureScheme(c.vers, cert, hs.certReq.supportedSignatureAlgorithms)
555	if err != nil {
556		// getClientCertificate returned a certificate incompatible with the
557		// CertificateRequestInfo supported signature algorithms.
558		c.sendAlert(alertHandshakeFailure)
559		return err
560	}
561
562	sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerifyMsg.signatureAlgorithm)
563	if err != nil {
564		return c.sendAlert(alertInternalError)
565	}
566
567	signed := signedMessage(sigHash, clientSignatureContext, hs.transcript)
568	signOpts := crypto.SignerOpts(sigHash)
569	if sigType == signatureRSAPSS {
570		signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
571	}
572	sig, err := cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), signed, signOpts)
573	if err != nil {
574		c.sendAlert(alertInternalError)
575		return errors.New("tls: failed to sign handshake: " + err.Error())
576	}
577	certVerifyMsg.signature = sig
578
579	hs.transcript.Write(certVerifyMsg.marshal())
580	if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil {
581		return err
582	}
583
584	return nil
585}
586
587func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
588	c := hs.c
589
590	finished := &finishedMsg{
591		verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
592	}
593
594	hs.transcript.Write(finished.marshal())
595	if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
596		return err
597	}
598
599	c.out.setTrafficSecret(hs.suite, hs.trafficSecret)
600
601	if !c.config.SessionTicketsDisabled && c.config.ClientSessionCache != nil {
602		c.resumptionSecret = hs.suite.deriveSecret(hs.masterSecret,
603			resumptionLabel, hs.transcript)
604	}
605
606	return nil
607}
608
609func (c *Conn) handleNewSessionTicket(msg *newSessionTicketMsgTLS13) error {
610	if !c.isClient {
611		c.sendAlert(alertUnexpectedMessage)
612		return errors.New("tls: received new session ticket from a client")
613	}
614
615	if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
616		return nil
617	}
618
619	// See RFC 8446, Section 4.6.1.
620	if msg.lifetime == 0 {
621		return nil
622	}
623	lifetime := time.Duration(msg.lifetime) * time.Second
624	if lifetime > maxSessionTicketLifetime {
625		c.sendAlert(alertIllegalParameter)
626		return errors.New("tls: received a session ticket with invalid lifetime")
627	}
628
629	cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
630	if cipherSuite == nil || c.resumptionSecret == nil {
631		return c.sendAlert(alertInternalError)
632	}
633
634	// Save the resumption_master_secret and nonce instead of deriving the PSK
635	// to do the least amount of work on NewSessionTicket messages before we
636	// know if the ticket will be used. Forward secrecy of resumed connections
637	// is guaranteed by the requirement for pskModeDHE.
638	session := &ClientSessionState{
639		sessionTicket:      msg.label,
640		vers:               c.vers,
641		cipherSuite:        c.cipherSuite,
642		masterSecret:       c.resumptionSecret,
643		serverCertificates: c.peerCertificates,
644		verifiedChains:     c.verifiedChains,
645		receivedAt:         c.config.time(),
646		nonce:              msg.nonce,
647		useBy:              c.config.time().Add(lifetime),
648		ageAdd:             msg.ageAdd,
649	}
650
651	cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
652	c.config.ClientSessionCache.Put(cacheKey, session)
653
654	return nil
655}
656