1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// TLS low level connection and record layer
6
7package tls
8
9import (
10	"bytes"
11	"context"
12	"crypto/cipher"
13	"crypto/subtle"
14	"crypto/x509"
15	"errors"
16	"fmt"
17	"hash"
18	"io"
19	"net"
20	"sync"
21	"sync/atomic"
22	"time"
23)
24
// A Conn represents a secured connection.
// It implements the net.Conn interface.
//
// Read-side record-layer state lives in in and write-side state in out; each
// is a halfConn with its own embedded mutex.
type Conn struct {
	// constant
	//
	// conn is the underlying transport: records are read from it by
	// readRecordOrCCS (via readFromUntil) and written to it by write and flush.
	conn        net.Conn
	isClient    bool
	handshakeFn func(context.Context) error // (*Conn).clientHandshake or serverHandshake

	// handshakeStatus is 1 if the connection is currently transferring
	// application data (i.e. is not currently processing a handshake).
	// This field is only to be accessed with sync/atomic.
	handshakeStatus uint32
	// constant after handshake; protected by handshakeMutex
	handshakeMutex sync.Mutex
	handshakeErr   error   // error resulting from handshake
	vers           uint16  // TLS version
	haveVers       bool    // version has been negotiated
	config         *Config // configuration passed to constructor
	// handshakes counts the number of handshakes performed on the
	// connection so far. If renegotiation is disabled then this is either
	// zero or one.
	handshakes       int
	didResume        bool // whether this connection was a session resumption
	cipherSuite      uint16
	ocspResponse     []byte   // stapled OCSP response
	scts             [][]byte // signed certificate timestamps from server
	peerCertificates []*x509.Certificate
	// verifiedChains contains the certificate chains that we built, as
	// opposed to the ones presented by the server.
	verifiedChains [][]*x509.Certificate
	// serverName contains the server name indicated by the client, if any.
	serverName string
	// secureRenegotiation is true if the server echoed the secure
	// renegotiation extension. (This is meaningless as a server because
	// renegotiation is not supported in that case.)
	secureRenegotiation bool
	// ekm is a closure for exporting keying material.
	ekm func(label string, context []byte, length int) ([]byte, error)
	// resumptionSecret is the resumption_master_secret for handling
	// NewSessionTicket messages. nil if config.SessionTicketsDisabled.
	resumptionSecret []byte

	// ticketKeys is the set of active session ticket keys for this
	// connection. The first one is used to encrypt new tickets and
	// all are tried to decrypt tickets.
	ticketKeys []ticketKey

	// clientFinishedIsFirst is true if the client sent the first Finished
	// message during the most recent handshake. This is recorded because
	// the first transmitted Finished message is the tls-unique
	// channel-binding value.
	clientFinishedIsFirst bool

	// closeNotifyErr is any error from sending the alertCloseNotify record.
	closeNotifyErr error
	// closeNotifySent is true if the Conn attempted to send an
	// alertCloseNotify record.
	closeNotifySent bool

	// clientFinished and serverFinished contain the Finished message sent
	// by the client or server in the most recent handshake. This is
	// retained to support the renegotiation extension and tls-unique
	// channel-binding.
	clientFinished [12]byte
	serverFinished [12]byte

	// clientProtocol is the negotiated ALPN protocol.
	clientProtocol string

	// input/output
	//
	// in and out hold the record-layer state for the read and write
	// directions respectively (readRecordOrCCS uses c.in, sendAlert locks
	// and uses c.out).
	in, out   halfConn
	rawInput  bytes.Buffer // raw input, starting with a record header
	input     bytes.Reader // application data waiting to be read, from rawInput.Next
	hand      bytes.Buffer // handshake data waiting to be read
	buffering bool         // whether records are buffered in sendBuf
	sendBuf   []byte       // a buffer of records waiting to be sent

	// bytesSent counts the bytes written to the underlying connection by
	// write and flush; maxPayloadSizeForWrite compares it against
	// recordSizeBoostThreshold. packetsSent is incremented by
	// maxPayloadSizeForWrite for each dynamically sized record and drives
	// the linear growth of the record size.
	bytesSent   int64
	packetsSent int64

	// retryCount counts the number of consecutive non-advancing records
	// received by Conn.readRecord. That is, records that neither advance the
	// handshake, nor deliver application data. Protected by in.Mutex.
	retryCount int

	// activeCall is an atomic int32; the low bit is whether Close has
	// been called. the rest of the bits are the number of goroutines
	// in Conn.Write.
	activeCall int32

	// tmp is small scratch space; sendAlertLocked builds the two-byte alert
	// payload in tmp[0:2].
	tmp [16]byte
}
119
120// Access to net.Conn methods.
121// Cannot just embed net.Conn because that would
122// export the struct field too.
123
124// LocalAddr returns the local network address.
125func (c *Conn) LocalAddr() net.Addr {
126	return c.conn.LocalAddr()
127}
128
129// RemoteAddr returns the remote network address.
130func (c *Conn) RemoteAddr() net.Addr {
131	return c.conn.RemoteAddr()
132}
133
134// SetDeadline sets the read and write deadlines associated with the connection.
135// A zero value for t means Read and Write will not time out.
136// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
137func (c *Conn) SetDeadline(t time.Time) error {
138	return c.conn.SetDeadline(t)
139}
140
141// SetReadDeadline sets the read deadline on the underlying connection.
142// A zero value for t means Read will not time out.
143func (c *Conn) SetReadDeadline(t time.Time) error {
144	return c.conn.SetReadDeadline(t)
145}
146
147// SetWriteDeadline sets the write deadline on the underlying connection.
148// A zero value for t means Write will not time out.
149// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error.
150func (c *Conn) SetWriteDeadline(t time.Time) error {
151	return c.conn.SetWriteDeadline(t)
152}
153
// A halfConn represents one direction of the record layer
// connection, either sending or receiving.
//
// Callers hold the embedded Mutex around record-layer operations. For
// example, sendAlert locks c.out before calling sendAlertLocked, and methods
// with a "Locked" suffix such as setErrorLocked expect the lock to be held.
type halfConn struct {
	sync.Mutex

	err     error       // first permanent error, recorded by setErrorLocked
	version uint16      // protocol version (prepareCipherSpec updates it)
	cipher  interface{} // cipher algorithm: cipher.Stream, aead or cbcMode; nil means records are unprotected
	mac     hash.Hash   // record MAC used alongside non-AEAD ciphers; nil when no separate MAC is checked
	seq     [8]byte     // 64-bit sequence number, big-endian (incSeq carries from seq[7] toward seq[0])

	// scratchBuf is sized for the 13-byte TLS 1.2 AEAD additional data
	// (8-byte sequence number, type, 2-byte version, 2-byte length).
	scratchBuf [13]byte // to avoid allocs; interface method args escape

	nextCipher interface{} // next encryption state, installed by changeCipherSpec
	nextMac    hash.Hash   // next MAC algorithm, installed by changeCipherSpec

	trafficSecret []byte // current TLS 1.3 traffic secret, set by setTrafficSecret
}
172
// permanentError wraps a net.Error so that it is never reported as
// temporary, while still exposing the original message, timeout status and
// the wrapped error itself (via Unwrap).
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message unchanged.
func (pe *permanentError) Error() string {
	return pe.err.Error()
}

// Unwrap returns the wrapped net.Error for errors.Is / errors.As.
func (pe *permanentError) Unwrap() error {
	return pe.err
}

// Timeout reports whether the wrapped error was a timeout.
func (pe *permanentError) Timeout() bool {
	return pe.err.Timeout()
}

// Temporary always reports false: the error must not be retried.
func (pe *permanentError) Temporary() bool {
	return false
}
181
182func (hc *halfConn) setErrorLocked(err error) error {
183	if e, ok := err.(net.Error); ok {
184		hc.err = &permanentError{err: e}
185	} else {
186		hc.err = err
187	}
188	return hc.err
189}
190
191// prepareCipherSpec sets the encryption and MAC states
192// that a subsequent changeCipherSpec will use.
193func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac hash.Hash) {
194	hc.version = version
195	hc.nextCipher = cipher
196	hc.nextMac = mac
197}
198
199// changeCipherSpec changes the encryption and MAC states
200// to the ones previously passed to prepareCipherSpec.
201func (hc *halfConn) changeCipherSpec() error {
202	if hc.nextCipher == nil || hc.version == VersionTLS13 {
203		return alertInternalError
204	}
205	hc.cipher = hc.nextCipher
206	hc.mac = hc.nextMac
207	hc.nextCipher = nil
208	hc.nextMac = nil
209	for i := range hc.seq {
210		hc.seq[i] = 0
211	}
212	return nil
213}
214
215func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, secret []byte) {
216	hc.trafficSecret = secret
217	key, iv := suite.trafficKey(secret)
218	hc.cipher = suite.aead(key, iv)
219	for i := range hc.seq {
220		hc.seq[i] = 0
221	}
222}
223
224// incSeq increments the sequence number.
225func (hc *halfConn) incSeq() {
226	for i := 7; i >= 0; i-- {
227		hc.seq[i]++
228		if hc.seq[i] != 0 {
229			return
230		}
231	}
232
233	// Not allowed to let sequence number wrap.
234	// Instead, must renegotiate before it does.
235	// Not likely enough to bother.
236	panic("TLS: sequence number wraparound")
237}
238
239// explicitNonceLen returns the number of bytes of explicit nonce or IV included
240// in each record. Explicit nonces are present only in CBC modes after TLS 1.0
241// and in certain AEAD modes in TLS 1.2.
242func (hc *halfConn) explicitNonceLen() int {
243	if hc.cipher == nil {
244		return 0
245	}
246
247	switch c := hc.cipher.(type) {
248	case cipher.Stream:
249		return 0
250	case aead:
251		return c.explicitNonceLen()
252	case cbcMode:
253		// TLS 1.1 introduced a per-record explicit IV to fix the BEAST attack.
254		if hc.version >= VersionTLS11 {
255			return c.BlockSize()
256		}
257		return 0
258	default:
259		panic("unknown cipher type")
260	}
261}
262
// extractPadding returns, in constant time, the length of the padding to remove
// from the end of payload. It also returns a byte which is equal to 255 if the
// padding was valid and 0 otherwise. See RFC 2246, Section 6.2.3.2.
//
// The last byte of payload is the padding length L, and a valid payload ends
// in L+1 bytes that all equal L. toRemove counts those L+1 bytes. When the
// padding is invalid, L is forced to zero so toRemove is 1. An empty payload
// returns (0, 0). Only the final min(256, len(payload)) bytes are examined
// (the maximum padding plus its length byte). Every one of them is visited
// regardless of L, so the loop's running time does not depend on the secret
// padding length.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// if len(payload) >= paddingLen + 1 (the padding fits) then the MSB of
	// (the low 32 bits of) t is zero, so good becomes 255; otherwise the
	// subtraction wraps around and good becomes 0.
	good = byte(int32(^t) >> 31)

	// The maximum possible padding length plus the actual length field
	toCheck := 256
	// The length of the padded data is public, so we can use an if here
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// if i <= paddingLen then the MSB of t is zero
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Inside the claimed padding (mask == 0xff), clear every bit of good
		// where b differs from paddingLen; outside it (mask == 0) this is a no-op.
		good &^= mask&paddingLen ^ mask&b
	}

	// We AND together the bits of good and replicate the result across
	// all the bits.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Zero the padding length on error. This ensures any unchecked bytes
	// are included in the MAC. Otherwise, an attacker that could
	// distinguish MAC failures from padding failures could mount an attack
	// similar to POODLE in SSL 3.0: given a good ciphertext that uses a
	// full block's worth of padding, replace the final block with another
	// block. If the MAC check passed but the padding check failed, the
	// last byte of that block decrypted to the block size.
	//
	// See also macAndPaddingGood logic below.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
312
// roundUp returns the smallest multiple of b that is greater than or equal
// to a. b must be positive.
func roundUp(a, b int) int {
	remainder := a % b
	padding := (b - remainder) % b
	return a + padding
}
316
// cbcMode is an interface for block ciphers using cipher block chaining.
// Beyond cipher.BlockMode it lets the record layer reset the IV: encrypt and
// decrypt call SetIV with the per-record explicit IV when one is present
// (TLS 1.1 and later, see explicitNonceLen).
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
322
// decrypt authenticates and decrypts the record if protection is active at
// this stage. The returned plaintext might overlap with the input.
//
// record must hold a complete record: the recordHeaderLen-byte header
// followed by the payload. Decryption happens in place, so the payload bytes
// of record are overwritten. When a separate MAC is in use, the length bytes
// record[3:5] are also rewritten to the plaintext length so the header can be
// fed to the MAC.
//
// On success it returns the plaintext, the content type (for TLS 1.3 this is
// the inner ContentType recovered from the decrypted payload), and a nil
// error, and it increments the sequence number. On failure the error is
// always an alert value (alertBadRecordMAC, alertUnexpectedMessage or
// alertRecordOverflow) and the sequence number is unchanged. In TLS 1.3,
// change_cipher_spec records are returned untouched without consuming a
// sequence number. It panics on an unrecognized cipher type.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, change_cipher_spec messages are to be ignored without being
	// decrypted. See RFC 8446, Appendix D.4.
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// Only the CBC branch updates these; the defaults mean "no padding,
	// padding valid" so the MAC check below works unchanged for stream ciphers.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			nonce := payload[:explicitNonceLen]
			if len(nonce) == 0 {
				// No explicit nonce on the wire: the sequence number is the nonce.
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			var additionalData []byte
			if hc.version == VersionTLS13 {
				// TLS 1.3 authenticates the record header itself (RFC 8446, Section 5.2).
				additionalData = record[:recordHeaderLen]
			} else {
				// TLS 1.2: seq_num || type || version || plaintext length.
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// Reject anything that cannot hold the explicit IV plus a MAC and at
			// least one padding byte, rounded up to whole blocks.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// In a limited attempt to protect against CBC padding oracles like
			// Lucky13, the data past paddingLen (which is secret) is passed to
			// the MAC function as extra data, to be fed into the HMAC after
			// computing the digest. This makes the MAC roughly constant time as
			// long as the digest computation is constant time and does not
			// affect the subsequent write, modulo cache effects.
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			// Protected TLS 1.3 records always carry the outer type application_data.
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			// +1 allows for the inner ContentType byte.
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Remove padding and find the ContentType scanning from the end.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	// A separate MAC is only configured for non-AEAD suites (presumably the
	// stream and CBC ones); it covers the plaintext that precedes it.
	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 }
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// This is equivalent to checking the MACs and paddingGood
		// separately, but in constant-time to prevent distinguishing
		// padding failures from MAC failures. Depending on what value
		// of paddingLen was returned on bad padding, distinguishing
		// bad MAC from bad padding can lead to an attack.
		//
		// See also the logic at the end of extractPadding.
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
448
// sliceForAppend grows in by n bytes and returns the grown slice (head)
// together with the newly added n-byte region (tail), which aliases the end
// of head. When in already has enough spare capacity, head shares its
// backing array and nothing is allocated. Otherwise a new array is
// allocated and the existing contents are copied into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	want := len(in) + n
	if cap(in) < want {
		head = make([]byte, want)
		copy(head, in)
	} else {
		head = in[:want]
	}
	return head, head[len(in):]
}
462
// encrypt encrypts payload, adding the appropriate nonce and/or MAC, and
// appends it to record, which must already contain the record header.
//
// When no cipher is active, payload is appended as is: the header is left
// untouched and the sequence number is not advanced. Otherwise the length
// field record[3:5] is rewritten to cover the protected payload and the
// sequence number is incremented. For TLS 1.3 the real content type is
// appended to the plaintext and record[0] is replaced with application_data.
// rand is read only when a random explicit IV is needed; the only error
// returned comes from reading it. It panics on an unrecognized cipher type.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// The AES-GCM construction in TLS has an explicit nonce so that the
			// nonce can be random. However, the nonce is only 8 bytes which is
			// too small for a secure, random nonce. Therefore we use the
			// sequence number as the nonce. The 3DES-CBC construction also has
			// an 8 bytes nonce but its nonces must be unpredictable (see RFC
			// 5246, Appendix F.3), forcing us to use randomness. That's not
			// 3DES' biggest problem anyway because the birthday bound on block
			// collision is reached first due to its similarly small block size
			// (see the Sweet32 attack).
			copy(explicitNonce, hc.seq[:])
		} else {
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC-then-encrypt: the MAC is computed over the plaintext and then
		// encrypted together with it.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			record = append(record, payload...)

			// Encrypt the actual ContentType and replace the plaintext one.
			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so its length field must
			// already hold the final ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		// Always at least one byte of padding, so a full block is added when
		// plaintextLen is already block-aligned.
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Update length to include nonce, MAC and any block padding needed.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
549
// RecordHeaderError is returned when a TLS record header is invalid.
type RecordHeaderError struct {
	// Msg is a human-readable description of what was wrong.
	Msg string
	// RecordHeader holds the five header bytes of the record that
	// triggered the error.
	RecordHeader [5]byte
	// Conn is the underlying net.Conn, provided when a client's initial
	// handshake did not look like TLS. It is nil if a handshake has already
	// taken place or a TLS alert has been written to the connection.
	Conn net.Conn
}

// Error implements the error interface by prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
565
566func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
567	err.Msg = msg
568	err.Conn = conn
569	copy(err.RecordHeader[:], c.rawInput.Bytes())
570	return err
571}
572
573func (c *Conn) readRecord() error {
574	return c.readRecordOrCCS(false)
575}
576
577func (c *Conn) readChangeCipherSpec() error {
578	return c.readRecordOrCCS(true)
579}
580
// readRecordOrCCS reads one or more TLS records from the connection and
// updates the record layer state. Some invariants:
//   * c.in must be locked
//   * c.input must be empty
// During the handshake one and only one of the following will happen:
//   - c.hand grows
//   - c.in.changeCipherSpec is called
//   - an error is returned
// After the handshake one and only one of the following will happen:
//   - c.hand grows
//   - c.input is set
//   - an error is returned
//
// If a permanent read error was recorded earlier it is returned immediately.
// Non-advancing records (warning alerts, empty application data, TLS 1.3
// change_cipher_spec) are dropped via retryReadRecord, which bounds how many
// may arrive in a row.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.handshakeComplete()

	// This function modifies c.rawInput, which owns the c.input memory.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	// Read header, payload.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// RFC 8446, Section 6.1 suggests that EOF without an alertCloseNotify
		// is an error, but popular web sites seem to do this, so we accept it
		// if and only if at the record boundary.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned without being recorded, so a
		// later call can retry; everything else becomes permanent.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No valid TLS record has a type of 0x80, however SSLv2 handshakes
	// start with a uint16 length where the MSB is set and the first record
	// is always < 256 bytes long. Therefore typ == 0x80 strongly suggests
	// an SSLv2 client.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	n := int(hdr[3])<<8 | int(hdr[4])
	if c.haveVers && c.vers != VersionTLS13 && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// First message, be extra suspicious: this might not be a TLS
		// client. Bail out before reading a full 'body', if possible.
		// The current max version is 3.3 so if the version is >= 16.0,
		// it's probably not real.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	// && binds tighter than ||: this reads
	// (TLS 1.3 && n > maxCiphertextTLS13) || n > maxCiphertext.
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process message.
	record := c.rawInput.Next(recordHeaderLen + n)
	// decrypt only ever returns alert values as errors, so the type
	// assertion below is safe.
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application Data messages are always protected.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		// This is a state-advancing message: reset the retry count.
		c.retryCount = 0
	}

	// Handshake messages MUST NOT be interleaved with other record types in TLS 1.3.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		// TLS 1.3 treats every alert other than close_notify as fatal,
		// regardless of its level byte.
		if c.vers == VersionTLS13 {
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the record on the floor and retry.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// Handshake messages are not allowed to fragment across the CCS.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// In TLS 1.3, change_cipher_spec records are ignored until the
		// Finished. See RFC 8446, Appendix D.4. Note that according to Section
		// 5, a server can send a ChangeCipherSpec before its ServerHello, when
		// c.vers is still unset. That's not useful though and suspicious if the
		// server then selects a lower protocol version, so don't allow that.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Some OpenSSL servers send empty records in order to randomize the
		// CBC IV. Ignore a limited number of empty records.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// Note that data is owned by c.rawInput, following the Next call above,
		// to avoid copying the plaintext. This is safe because c.rawInput is
		// not read from or written to until c.input is drained.
		c.input.Reset(data)

	case recordTypeHandshake:
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
753
754// retryReadRecord recurses into readRecordOrCCS to drop a non-advancing record, like
755// a warning alert, empty application_data, or a change_cipher_spec in TLS 1.3.
756func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
757	c.retryCount++
758	if c.retryCount > maxUselessRecords {
759		c.sendAlert(alertUnexpectedMessage)
760		return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
761	}
762	return c.readRecordOrCCS(expectChangeCipherSpec)
763}
764
// atLeastReader reads from R until at least N bytes have been consumed, and
// then reports io.EOF. Unlike io.LimitedReader it never truncates the final
// Read, so more than N bytes may be returned in total. An io.EOF from R that
// arrives before N bytes have been read is turned into io.ErrUnexpectedEOF.
type atLeastReader struct {
	R io.Reader
	N int64
}

func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	got, rerr := r.R.Read(p)
	r.N -= int64(got) // won't underflow unless len(p) >= n > 9223372036854775809
	switch {
	case r.N > 0 && rerr == io.EOF:
		return got, io.ErrUnexpectedEOF
	case r.N <= 0 && rerr == nil:
		return got, io.EOF
	default:
		return got, rerr
	}
}
787
788// readFromUntil reads from r into c.rawInput until c.rawInput contains
789// at least n bytes or else returns an error.
790func (c *Conn) readFromUntil(r io.Reader, n int) error {
791	if c.rawInput.Len() >= n {
792		return nil
793	}
794	needs := n - c.rawInput.Len()
795	// There might be extra input waiting on the wire. Make a best effort
796	// attempt to fetch it so that it can be used in (*Conn).Read to
797	// "predict" closeNotify alerts.
798	c.rawInput.Grow(needs + bytes.MinRead)
799	_, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
800	return err
801}
802
803// sendAlert sends a TLS alert message.
804func (c *Conn) sendAlertLocked(err alert) error {
805	switch err {
806	case alertNoRenegotiation, alertCloseNotify:
807		c.tmp[0] = alertLevelWarning
808	default:
809		c.tmp[0] = alertLevelError
810	}
811	c.tmp[1] = byte(err)
812
813	_, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
814	if err == alertCloseNotify {
815		// closeNotify is a special case in that it isn't an error.
816		return writeErr
817	}
818
819	return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
820}
821
822// sendAlert sends a TLS alert message.
823func (c *Conn) sendAlert(err alert) error {
824	c.out.Lock()
825	defer c.out.Unlock()
826	return c.sendAlertLocked(err)
827}
828
const (
	// tcpMSSEstimate is a conservative estimate of the TCP maximum segment
	// size (MSS): the IPv6 minimum MTU (1280 bytes) minus an IPv6 header
	// (40 bytes) and a TCP header carrying timestamps (32 bytes). A fixed
	// value avoids the complexity of querying the kernel for the real MSS.
	tcpMSSEstimate = 1280 - 40 - 32

	// recordSizeBoostThreshold is the number of bytes sent after which
	// maxPayloadSizeForWrite always returns the maximum record size.
	recordSizeBoostThreshold = 128 << 10
)
842
// maxPayloadSizeForWrite returns the maximum TLS payload size to use for the
// next application data record. There is the following trade-off:
//
//   - For latency-sensitive applications, such as web browsing, each TLS
//     record should fit in one TCP segment.
//   - For throughput-sensitive applications, such as large file transfers,
//     larger TLS records better amortize framing and encryption overheads.
//
// A simple heuristic that works well in practice is to use small records for
// the first part of a connection's data (the reference below suggests about
// 1MB), then use larger records for subsequent data, and reset back to smaller
// records after the connection becomes idle. See "High Performance Browser
// Networking", Chapter 4, or:
// https://www.igvita.com/2013/10/24/optimizing-tls-record-size-and-buffering-latency/
//
// This implementation switches to full-size (maxPlaintext) records once
// c.bytesSent reaches recordSizeBoostThreshold (128 KiB). Before that, the
// first record's payload is sized so that the protected record fits in roughly
// one TCP segment (tcpMSSEstimate minus the record header, explicit nonce, MAC
// or AEAD overhead, CBC padding, and the TLS 1.3 inner content type). Each
// subsequent record may carry one more segment's worth, capped at
// maxPlaintext.
//
// In the interests of simplicity and determinism, this code does not attempt
// to reset the record size once the connection is idle, however.
//
// Records that are not application data, and all records when
// Config.DynamicRecordSizingDisabled is set, always use maxPlaintext.
// Each call that reaches the growth computation increments c.packetsSent.
func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
	if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
		return maxPlaintext
	}

	// Past the boost threshold, stop ramping up and always use full records.
	if c.bytesSent >= recordSizeBoostThreshold {
		return maxPlaintext
	}

	// Subtract TLS overheads to get the maximum payload size.
	payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
	if c.out.cipher != nil {
		switch ciph := c.out.cipher.(type) {
		case cipher.Stream:
			payloadBytes -= c.out.mac.Size()
		case cipher.AEAD:
			payloadBytes -= ciph.Overhead()
		case cbcMode:
			blockSize := ciph.BlockSize()
			// The payload must fit in a multiple of blockSize, with
			// room for at least one padding byte.
			// (The mask only rounds down correctly when blockSize is a
			// power of two.)
			payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
			// The MAC is appended before padding so affects the
			// payload size directly.
			payloadBytes -= c.out.mac.Size()
		default:
			panic("unknown cipher type")
		}
	}
	if c.vers == VersionTLS13 {
		payloadBytes-- // encrypted ContentType
	}

	// Allow packet growth in arithmetic progression up to max.
	pkt := c.packetsSent
	c.packetsSent++
	if pkt > 1000 {
		return maxPlaintext // avoid overflow in multiply below
	}

	n := payloadBytes * int(pkt+1)
	if n > maxPlaintext {
		n = maxPlaintext
	}
	return n
}
905
906func (c *Conn) write(data []byte) (int, error) {
907	if c.buffering {
908		c.sendBuf = append(c.sendBuf, data...)
909		return len(data), nil
910	}
911
912	n, err := c.conn.Write(data)
913	c.bytesSent += int64(n)
914	return n, err
915}
916
917func (c *Conn) flush() (int, error) {
918	if len(c.sendBuf) == 0 {
919		return 0, nil
920	}
921
922	n, err := c.conn.Write(c.sendBuf)
923	c.bytesSent += int64(n)
924	c.sendBuf = nil
925	c.buffering = false
926	return n, err
927}
928
// outBufPool holds reusable scratch buffers for writeRecordLocked. It stores
// pointers to slice headers, so that returning a buffer to the pool does not
// force a fresh heap allocation of the header.
var outBufPool = sync.Pool{
	New: func() interface{} {
		var buf []byte
		return &buf
	},
}
935
// writeRecordLocked writes a TLS record with the given type and payload to the
// connection and updates the record layer state.
//
// data is split across as many records as needed, each carrying at most
// maxPayloadSizeForWrite(typ) bytes of plaintext. An empty data slice writes
// no record. Each record is framed with a header (byte 0: record type,
// bytes 1-2: record-layer version, bytes 3-4: payload length), protected by
// c.out.encrypt, and handed to c.write, which may buffer it. The returned
// count is the number of bytes of data (not wire bytes) fully written before
// any error.
//
// After writing a ChangeCipherSpec record on a connection that is not TLS
// 1.3, the outgoing half-connection switches cipher state via
// c.out.changeCipherSpec.
//
// Callers in this file hold c.out's lock; writeRecord is the locking wrapper.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// You might be tempted to simplify this by just passing &outBuf to Put,
		// but that would make the local copy of the outBuf slice header escape
		// to the heap, causing an allocation. Instead, we keep around the
		// pointer to the slice header returned by Get, which is already on the
		// heap, and overwrite and return that.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		// Reset outBuf to hold just the record header (sliceForAppend
		// presumably reuses existing capacity when it is large enough).
		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// Some TLS servers fail if the record version is
			// greater than TLS 1.0 for the initial ClientHello.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 froze the record layer version to 1.2.
			// See RFC 8446, Section 5.1.
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		// Plaintext length. NOTE(review): presumably c.out.encrypt rewrites
		// this to the protected length; confirm in halfConn.encrypt.
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		// NOTE(review): the type assertion assumes changeCipherSpec only
		// returns alert values; any other error type would panic here.
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
995
996// writeRecord writes a TLS record with the given type and payload to the
997// connection and updates the record layer state.
998func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
999	c.out.Lock()
1000	defer c.out.Unlock()
1001
1002	return c.writeRecordLocked(typ, data)
1003}
1004
// readHandshake reads the next handshake message from
// the record layer.
//
// It reads records until c.hand holds a complete message: a 4-byte header
// (1-byte type, 3-byte big-endian body length) followed by the body. It then
// consumes exactly that message from c.hand. Messages longer than
// maxHandshake are rejected with an internal_error alert. The type byte
// selects the concrete message struct. For NewSessionTicket, Certificate, and
// CertificateRequest, the TLS 1.3 variant is chosen when c.vers is
// VersionTLS13. Unknown types, and messages that fail to unmarshal, are
// answered with an unexpected_message alert. These errors are recorded on
// c.in via setErrorLocked.
//
// The returned value is the parsed handshake message; callers type-assert it
// to the message they expect. Read reaches this function (through
// handlePostHandshakeMessage) while holding c.in's lock.
func (c *Conn) readHandshake() (interface{}, error) {
	for c.hand.Len() < 4 {
		if err := c.readRecord(); err != nil {
			return nil, err
		}
	}

	// Peek at the header; Bytes does not consume it (the Next call below
	// consumes header and body together). data[1:4] is the body length.
	data := c.hand.Bytes()
	n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
	if n > maxHandshake {
		c.sendAlertLocked(alertInternalError)
		return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
	}
	for c.hand.Len() < 4+n {
		if err := c.readRecord(); err != nil {
			return nil, err
		}
	}
	// Consume the whole message; data[0] is the handshake message type.
	data = c.hand.Next(4 + n)
	var m handshakeMessage
	switch data[0] {
	case typeHelloRequest:
		m = new(helloRequestMsg)
	case typeClientHello:
		m = new(clientHelloMsg)
	case typeServerHello:
		m = new(serverHelloMsg)
	case typeNewSessionTicket:
		if c.vers == VersionTLS13 {
			m = new(newSessionTicketMsgTLS13)
		} else {
			m = new(newSessionTicketMsg)
		}
	case typeCertificate:
		if c.vers == VersionTLS13 {
			m = new(certificateMsgTLS13)
		} else {
			m = new(certificateMsg)
		}
	case typeCertificateRequest:
		if c.vers == VersionTLS13 {
			m = new(certificateRequestMsgTLS13)
		} else {
			m = &certificateRequestMsg{
				hasSignatureAlgorithm: c.vers >= VersionTLS12,
			}
		}
	case typeCertificateStatus:
		m = new(certificateStatusMsg)
	case typeServerKeyExchange:
		m = new(serverKeyExchangeMsg)
	case typeServerHelloDone:
		m = new(serverHelloDoneMsg)
	case typeClientKeyExchange:
		m = new(clientKeyExchangeMsg)
	case typeCertificateVerify:
		m = &certificateVerifyMsg{
			hasSignatureAlgorithm: c.vers >= VersionTLS12,
		}
	case typeFinished:
		m = new(finishedMsg)
	case typeEncryptedExtensions:
		m = new(encryptedExtensionsMsg)
	case typeEndOfEarlyData:
		m = new(endOfEarlyDataMsg)
	case typeKeyUpdate:
		m = new(keyUpdateMsg)
	default:
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	// The handshake message unmarshalers
	// expect to be able to keep references to data,
	// so pass in a fresh copy that won't be overwritten.
	data = append([]byte(nil), data...)

	if !m.unmarshal(data) {
		return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}
	return m, nil
}
1088
// errShutdown is returned by Write once a close_notify alert has been sent on
// the connection.
var errShutdown = errors.New("tls: protocol is shutdown")
1092
// Write writes data to the connection.
//
// As Write calls Handshake, in order to prevent indefinite blocking a deadline
// must be set for both Read and Write before Write is called when the handshake
// has not yet completed. See SetDeadline, SetReadDeadline, and
// SetWriteDeadline.
func (c *Conn) Write(b []byte) (int, error) {
	// interlock with Close below
	// c.activeCall: bit 0 is set once Close has started; each in-flight
	// Write adds 2 for its duration, which lets Close detect it.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			break
		}
	}
	defer atomic.AddInt32(&c.activeCall, -2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	// An error recorded by an earlier write on this half-connection is
	// returned again here.
	if err := c.out.err; err != nil {
		return 0, err
	}

	// The handshake status may have been cleared since Handshake returned,
	// e.g. by handleRenegotiation.
	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	// No application data may be sent after our close_notify.
	if c.closeNotifySent {
		return 0, errShutdown
	}

	// TLS 1.0 is susceptible to a chosen-plaintext
	// attack when using block mode ciphers due to predictable IVs.
	// This can be prevented by splitting each Application Data
	// record into two records, effectively randomizing the IV.
	//
	// https://www.openssl.org/~bodo/tls-cbc.txt
	// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
	// https://www.imperialviolet.org/2012/01/15/beastfollowup.html

	// m counts the single byte already written by the split, if any.
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1154
1155// handleRenegotiation processes a HelloRequest handshake message.
1156func (c *Conn) handleRenegotiation() error {
1157	if c.vers == VersionTLS13 {
1158		return errors.New("tls: internal error: unexpected renegotiation")
1159	}
1160
1161	msg, err := c.readHandshake()
1162	if err != nil {
1163		return err
1164	}
1165
1166	helloReq, ok := msg.(*helloRequestMsg)
1167	if !ok {
1168		c.sendAlert(alertUnexpectedMessage)
1169		return unexpectedMessageError(helloReq, msg)
1170	}
1171
1172	if !c.isClient {
1173		return c.sendAlert(alertNoRenegotiation)
1174	}
1175
1176	switch c.config.Renegotiation {
1177	case RenegotiateNever:
1178		return c.sendAlert(alertNoRenegotiation)
1179	case RenegotiateOnceAsClient:
1180		if c.handshakes > 1 {
1181			return c.sendAlert(alertNoRenegotiation)
1182		}
1183	case RenegotiateFreelyAsClient:
1184		// Ok.
1185	default:
1186		c.sendAlert(alertInternalError)
1187		return errors.New("tls: unknown Renegotiation value")
1188	}
1189
1190	c.handshakeMutex.Lock()
1191	defer c.handshakeMutex.Unlock()
1192
1193	atomic.StoreUint32(&c.handshakeStatus, 0)
1194	if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1195		c.handshakes++
1196	}
1197	return c.handshakeErr
1198}
1199
1200// handlePostHandshakeMessage processes a handshake message arrived after the
1201// handshake is complete. Up to TLS 1.2, it indicates the start of a renegotiation.
1202func (c *Conn) handlePostHandshakeMessage() error {
1203	if c.vers != VersionTLS13 {
1204		return c.handleRenegotiation()
1205	}
1206
1207	msg, err := c.readHandshake()
1208	if err != nil {
1209		return err
1210	}
1211
1212	c.retryCount++
1213	if c.retryCount > maxUselessRecords {
1214		c.sendAlert(alertUnexpectedMessage)
1215		return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1216	}
1217
1218	switch msg := msg.(type) {
1219	case *newSessionTicketMsgTLS13:
1220		return c.handleNewSessionTicket(msg)
1221	case *keyUpdateMsg:
1222		return c.handleKeyUpdate(msg)
1223	default:
1224		c.sendAlert(alertUnexpectedMessage)
1225		return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1226	}
1227}
1228
1229func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1230	cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1231	if cipherSuite == nil {
1232		return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1233	}
1234
1235	newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1236	c.in.setTrafficSecret(cipherSuite, newSecret)
1237
1238	if keyUpdate.updateRequested {
1239		c.out.Lock()
1240		defer c.out.Unlock()
1241
1242		msg := &keyUpdateMsg{}
1243		_, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
1244		if err != nil {
1245			// Surface the error at the next write.
1246			c.out.setErrorLocked(err)
1247			return nil
1248		}
1249
1250		newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1251		c.out.setTrafficSecret(cipherSuite, newSecret)
1252	}
1253
1254	return nil
1255}
1256
// Read reads data from the connection.
//
// As Read calls Handshake, in order to prevent indefinite blocking a deadline
// must be set for both Read and Write before Read is called when the handshake
// has not yet completed. See SetDeadline, SetReadDeadline, and
// SetWriteDeadline.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Put this after Handshake, in case people were calling
		// Read(nil) for the side effect of the Handshake.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	// Read records until application data is available. Handshake messages
	// arriving after the handshake are processed inline here: TLS 1.3
	// session tickets and key updates, or a renegotiation before TLS 1.3.
	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	// The error is ignored: the loop above guarantees c.input is non-empty.
	// n may be smaller than len(b).
	n, _ := c.input.Read(b)

	// If a close-notify alert is waiting, read it so that we can return (n,
	// EOF) instead of (n, nil), to signal to the HTTP response reading
	// goroutine that the connection is now closed. This eliminates a race
	// where the HTTP response reading goroutine would otherwise not observe
	// the EOF until its next read, by which time a client goroutine might
	// have already tried to reuse the HTTP connection for a new request.
	// See https://golang.org/cl/76400046 and https://golang.org/issue/3514
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err // will be io.EOF on closeNotify
		}
	}

	return n, nil
}
1305
// Close closes the connection.
//
// If no Write is in flight and the handshake has completed, Close first sends
// a close_notify alert (see closeNotify). The underlying connection is closed
// in every case. An error from closing it takes precedence over an error from
// sending the alert. Once Close has been called, further calls to Close and
// Write return net.ErrClosed.
func (c *Conn) Close() error {
	// Interlock with Conn.Write above.
	// Setting bit 0 of c.activeCall marks the Conn as closed; the value seen
	// before setting it (x) is non-zero only if Writes are still in flight.
	var x int32
	for {
		x = atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return net.ErrClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			break
		}
	}
	if x != 0 {
		// io.Writer and io.Closer should not be used concurrently.
		// If Close is called while a Write is currently in-flight,
		// interpret that as a sign that this Close is really just
		// being used to break the Write and/or clean up resources and
		// avoid sending the alertCloseNotify, which may block
		// waiting on handshakeMutex or the c.out mutex.
		return c.conn.Close()
	}

	var alertErr error
	if c.handshakeComplete() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1341
// errEarlyCloseWrite is returned by CloseWrite when the handshake has not yet
// completed.
var (
	errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
)
1343
1344// CloseWrite shuts down the writing side of the connection. It should only be
1345// called once the handshake has completed and does not call CloseWrite on the
1346// underlying connection. Most callers should just use Close.
1347func (c *Conn) CloseWrite() error {
1348	if !c.handshakeComplete() {
1349		return errEarlyCloseWrite
1350	}
1351
1352	return c.closeNotify()
1353}
1354
1355func (c *Conn) closeNotify() error {
1356	c.out.Lock()
1357	defer c.out.Unlock()
1358
1359	if !c.closeNotifySent {
1360		// Set a Write Deadline to prevent possibly blocking forever.
1361		c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1362		c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1363		c.closeNotifySent = true
1364		// Any subsequent writes will fail.
1365		c.SetWriteDeadline(time.Now())
1366	}
1367	return c.closeNotifyErr
1368}
1369
1370// Handshake runs the client or server handshake
1371// protocol if it has not yet been run.
1372//
1373// Most uses of this package need not call Handshake explicitly: the
1374// first Read or Write will call it automatically.
1375//
1376// For control over canceling or setting a timeout on a handshake, use
1377// HandshakeContext or the Dialer's DialContext method instead.
1378func (c *Conn) Handshake() error {
1379	return c.HandshakeContext(context.Background())
1380}
1381
1382// HandshakeContext runs the client or server handshake
1383// protocol if it has not yet been run.
1384//
1385// The provided Context must be non-nil. If the context is canceled before
1386// the handshake is complete, the handshake is interrupted and an error is returned.
1387// Once the handshake has completed, cancellation of the context will not affect the
1388// connection.
1389//
1390// Most uses of this package need not call HandshakeContext explicitly: the
1391// first Read or Write will call it automatically.
1392func (c *Conn) HandshakeContext(ctx context.Context) error {
1393	// Delegate to unexported method for named return
1394	// without confusing documented signature.
1395	return c.handshakeContext(ctx)
1396}
1397
1398func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
1399	handshakeCtx, cancel := context.WithCancel(ctx)
1400	// Note: defer this before starting the "interrupter" goroutine
1401	// so that we can tell the difference between the input being canceled and
1402	// this cancellation. In the former case, we need to close the connection.
1403	defer cancel()
1404
1405	// Start the "interrupter" goroutine, if this context might be canceled.
1406	// (The background context cannot).
1407	//
1408	// The interrupter goroutine waits for the input context to be done and
1409	// closes the connection if this happens before the function returns.
1410	if ctx.Done() != nil {
1411		done := make(chan struct{})
1412		interruptRes := make(chan error, 1)
1413		defer func() {
1414			close(done)
1415			if ctxErr := <-interruptRes; ctxErr != nil {
1416				// Return context error to user.
1417				ret = ctxErr
1418			}
1419		}()
1420		go func() {
1421			select {
1422			case <-handshakeCtx.Done():
1423				// Close the connection, discarding the error
1424				_ = c.conn.Close()
1425				interruptRes <- handshakeCtx.Err()
1426			case <-done:
1427				interruptRes <- nil
1428			}
1429		}()
1430	}
1431
1432	c.handshakeMutex.Lock()
1433	defer c.handshakeMutex.Unlock()
1434
1435	if err := c.handshakeErr; err != nil {
1436		return err
1437	}
1438	if c.handshakeComplete() {
1439		return nil
1440	}
1441
1442	c.in.Lock()
1443	defer c.in.Unlock()
1444
1445	c.handshakeErr = c.handshakeFn(handshakeCtx)
1446	if c.handshakeErr == nil {
1447		c.handshakes++
1448	} else {
1449		// If an error occurred during the handshake try to flush the
1450		// alert that might be left in the buffer.
1451		c.flush()
1452	}
1453
1454	if c.handshakeErr == nil && !c.handshakeComplete() {
1455		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
1456	}
1457
1458	return c.handshakeErr
1459}
1460
1461// ConnectionState returns basic TLS details about the connection.
1462func (c *Conn) ConnectionState() ConnectionState {
1463	c.handshakeMutex.Lock()
1464	defer c.handshakeMutex.Unlock()
1465	return c.connectionStateLocked()
1466}
1467
1468func (c *Conn) connectionStateLocked() ConnectionState {
1469	var state ConnectionState
1470	state.HandshakeComplete = c.handshakeComplete()
1471	state.Version = c.vers
1472	state.NegotiatedProtocol = c.clientProtocol
1473	state.DidResume = c.didResume
1474	state.NegotiatedProtocolIsMutual = true
1475	state.ServerName = c.serverName
1476	state.CipherSuite = c.cipherSuite
1477	state.PeerCertificates = c.peerCertificates
1478	state.VerifiedChains = c.verifiedChains
1479	state.SignedCertificateTimestamps = c.scts
1480	state.OCSPResponse = c.ocspResponse
1481	if !c.didResume && c.vers != VersionTLS13 {
1482		if c.clientFinishedIsFirst {
1483			state.TLSUnique = c.clientFinished[:]
1484		} else {
1485			state.TLSUnique = c.serverFinished[:]
1486		}
1487	}
1488	if c.config.Renegotiation != RenegotiateNever {
1489		state.ekm = noExportedKeyingMaterial
1490	} else {
1491		state.ekm = c.ekm
1492	}
1493	return state
1494}
1495
1496// OCSPResponse returns the stapled OCSP response from the TLS server, if
1497// any. (Only valid for client connections.)
1498func (c *Conn) OCSPResponse() []byte {
1499	c.handshakeMutex.Lock()
1500	defer c.handshakeMutex.Unlock()
1501
1502	return c.ocspResponse
1503}
1504
1505// VerifyHostname checks that the peer certificate chain is valid for
1506// connecting to host. If so, it returns nil; if not, it returns an error
1507// describing the problem.
1508func (c *Conn) VerifyHostname(host string) error {
1509	c.handshakeMutex.Lock()
1510	defer c.handshakeMutex.Unlock()
1511	if !c.isClient {
1512		return errors.New("tls: VerifyHostname called on TLS server connection")
1513	}
1514	if !c.handshakeComplete() {
1515		return errors.New("tls: handshake has not yet been performed")
1516	}
1517	if len(c.verifiedChains) == 0 {
1518		return errors.New("tls: handshake did not verify certificate chain")
1519	}
1520	return c.peerCertificates[0].VerifyHostname(host)
1521}
1522
1523func (c *Conn) handshakeComplete() bool {
1524	return atomic.LoadUint32(&c.handshakeStatus) == 1
1525}
1526