1// Copyright 2017 Google Inc. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bufio"
9	"bytes"
10	"crypto/cipher"
11	"encoding/binary"
12	"errors"
13	"fmt"
14	"io"
15	"net"
16	"strconv"
17	"sync/atomic"
18)
19
// UConn is a uTLS client connection: it embeds *Conn and adds the state used
// to build a customized (default Golang or mimicked) ClientHello before the
// handshake.
type UConn struct {
	*Conn

	// Extensions are the ClientHello extensions; MarshalClientHello writes
	// them in slice order and ApplyConfig applies each of them to the UConn.
	Extensions    []TLSExtension
	// ClientHelloID selects how BuildHandshakeState builds the ClientHello:
	// HelloGolang uses the default ClientHello, any other ID applies a preset.
	ClientHelloID ClientHelloID

	// ClientHelloBuilt is set by BuildHandshakeState after it succeeds; the
	// one-time ("only once") steps there are skipped while it is true.
	ClientHelloBuilt bool
	// HandshakeState holds the ClientHello (Hello), the session and the other
	// client handshake state used by Handshake.
	HandshakeState   ClientHandshakeState

	// GetSessionID, if set, is used by SetSessionState to derive the
	// ClientHello session ID from a non-empty session ticket, so the
	// session ID may or may not depend on the ticket; nil => random.
	GetSessionID func(ticket []byte) [32]byte

	// NOTE(review): presumably per-connection GREASE values, indexed up to
	// ssl_grease_last_index and filled when a preset is applied — confirm.
	greaseSeed [ssl_grease_last_index]uint16

	// omitSNIExtension is set by RemoveSNIExtension; BuildHandshakeState
	// drops the SNI extension after applying the preset when it is true.
	omitSNIExtension bool
}
36
37// UClient returns a new uTLS client, with behavior depending on clientHelloID.
38// Config CAN be nil, but make sure to eventually specify ServerName.
39func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
40	if config == nil {
41		config = &Config{}
42	}
43	tlsConn := Conn{conn: conn, config: config, isClient: true}
44	handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}}
45	uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState}
46	uconn.HandshakeState.uconn = &uconn
47	return &uconn
48}
49
50// BuildHandshakeState behavior varies based on ClientHelloID and
51// whether it was already called before.
52// If HelloGolang:
53//   [only once] make default ClientHello and overwrite existing state
54// If any other mimicking ClientHelloID is used:
55//   [only once] make ClientHello based on ID and overwrite existing state
56//   [each call] apply uconn.Extensions config to internal crypto/tls structures
57//   [each call] marshal ClientHello.
58//
59// BuildHandshakeState is automatically called before uTLS performs handshake,
60// amd should only be called explicitly to inspect/change fields of
61// default/mimicked ClientHello.
62func (uconn *UConn) BuildHandshakeState() error {
63	if uconn.ClientHelloID == HelloGolang {
64		if uconn.ClientHelloBuilt {
65			return nil
66		}
67
68		// use default Golang ClientHello.
69		hello, ecdheParams, err := uconn.makeClientHello()
70		if err != nil {
71			return err
72		}
73
74		uconn.HandshakeState.Hello = hello.getPublicPtr()
75		uconn.HandshakeState.State13.EcdheParams = ecdheParams
76		uconn.HandshakeState.C = uconn.Conn
77	} else {
78		if !uconn.ClientHelloBuilt {
79			err := uconn.applyPresetByID(uconn.ClientHelloID)
80			if err != nil {
81				return err
82			}
83			if uconn.omitSNIExtension {
84				uconn.removeSNIExtension()
85			}
86		}
87
88		err := uconn.ApplyConfig()
89		if err != nil {
90			return err
91		}
92		err = uconn.MarshalClientHello()
93		if err != nil {
94			return err
95		}
96	}
97	uconn.ClientHelloBuilt = true
98	return nil
99}
100
101// SetSessionState sets the session ticket, which may be preshared or fake.
102// If session is nil, the body of session ticket extension will be unset,
103// but the extension itself still MAY be present for mimicking purposes.
104// Session tickets to be reused - use same cache on following connections.
105func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
106	uconn.HandshakeState.Session = session
107	var sessionTicket []uint8
108	if session != nil {
109		sessionTicket = session.sessionTicket
110	}
111	uconn.HandshakeState.Hello.TicketSupported = true
112	uconn.HandshakeState.Hello.SessionTicket = sessionTicket
113
114	for _, ext := range uconn.Extensions {
115		st, ok := ext.(*SessionTicketExtension)
116		if !ok {
117			continue
118		}
119		st.Session = session
120		if session != nil {
121			if len(session.SessionTicket()) > 0 {
122				if uconn.GetSessionID != nil {
123					sid := uconn.GetSessionID(session.SessionTicket())
124					uconn.HandshakeState.Hello.SessionId = sid[:]
125					return nil
126				}
127			}
128			var sessionID [32]byte
129			_, err := io.ReadFull(uconn.config.rand(), sessionID[:])
130			if err != nil {
131				return err
132			}
133			uconn.HandshakeState.Hello.SessionId = sessionID[:]
134		}
135		return nil
136	}
137	return nil
138}
139
140// If you want session tickets to be reused - use same cache on following connections
141func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
142	uconn.config.ClientSessionCache = cache
143	uconn.HandshakeState.Hello.TicketSupported = true
144}
145
146// SetClientRandom sets client random explicitly.
147// BuildHandshakeFirst() must be called before SetClientRandom.
148// r must to be 32 bytes long.
149func (uconn *UConn) SetClientRandom(r []byte) error {
150	if len(r) != 32 {
151		return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r)))
152	} else {
153		uconn.HandshakeState.Hello.Random = make([]byte, 32)
154		copy(uconn.HandshakeState.Hello.Random, r)
155		return nil
156	}
157}
158
159func (uconn *UConn) SetSNI(sni string) {
160	hname := hostnameInSNI(sni)
161	uconn.config.ServerName = hname
162	for _, ext := range uconn.Extensions {
163		sniExt, ok := ext.(*SNIExtension)
164		if ok {
165			sniExt.ServerName = hname
166		}
167	}
168}
169
170// RemoveSNIExtension removes SNI from the list of extensions sent in ClientHello
171// It returns an error when used with HelloGolang ClientHelloID
172func (uconn *UConn) RemoveSNIExtension() error {
173	if uconn.ClientHelloID == HelloGolang {
174		return fmt.Errorf("Cannot call RemoveSNIExtension on a UConn with a HelloGolang ClientHelloID")
175	}
176	uconn.omitSNIExtension = true
177	return nil
178}
179
180func (uconn *UConn) removeSNIExtension() {
181	filteredExts := make([]TLSExtension, 0, len(uconn.Extensions))
182	for _, e := range uconn.Extensions {
183		if _, ok := e.(*SNIExtension); !ok {
184			filteredExts = append(filteredExts, e)
185		}
186	}
187	uconn.Extensions = filteredExts
188}
189
// Handshake runs the TLS handshake if it has not been run yet.
//
// It mirrors crypto/tls (*Conn).Handshake, except that on the client side it
// first calls BuildHandshakeState, so the default or mimicked uTLS
// ClientHello is built and marshaled, and then runs the uTLS clientHandshake.
// An error from BuildHandshakeState is returned directly and is not stored in
// c.handshakeErr. An error stored by an earlier attempt is returned again, and
// nil is returned if the handshake has already completed.
func (c *UConn) Handshake() error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// A stored handshake error is sticky: return it without retrying.
	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.handshakeComplete() {
		return nil
	}

	// Hold c.in's lock for the whole handshake.
	c.in.Lock()
	defer c.in.Unlock()

	if c.isClient {
		// [uTLS section begins]
		// Build (or rebuild) the ClientHello first; note that a failure here
		// is not recorded in c.handshakeErr.
		err := c.BuildHandshakeState()
		if err != nil {
			return err
		}
		// [uTLS section ends]

		c.handshakeErr = c.clientHandshake()
	} else {
		c.handshakeErr = c.serverHandshake()
	}
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// If an error occurred during the handshake try to flush the
		// alert that might be left in the buffer.
		c.flush()
	}

	// Guard against a handshake that reported success without completing.
	if c.handshakeErr == nil && !c.handshakeComplete() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}

	return c.handshakeErr
}
232
// Write writes data to the connection, performing the handshake first if it
// has not completed yet.
//
// Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one,
// not tls, so the uTLS ClientHello is used when the handshake runs from here.
func (c *UConn) Write(b []byte) (int, error) {
	// Interlock with Close: an odd activeCall means the connection is closed
	// (return errClosed); otherwise register this call by adding 2, which is
	// undone by the deferred AddInt32 when Write returns.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			defer atomic.AddInt32(&c.activeCall, -2)
			break
		}
	}

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	// A previous write failure recorded on c.out is returned again.
	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
	// attack when using block mode ciphers due to predictable IVs.
	// This can be prevented by splitting each Application Data
	// record into two records, effectively randomizing the IV.
	//
	// https://www.openssl.org/~bodo/tls-cbc.txt
	// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
	// https://www.imperialviolet.org/2012/01/15/beastfollowup.html

	// m counts the bytes already sent in the 1-byte split record, if any.
	var m int
	if len(b) > 1 && c.vers <= VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
290
// clientHandshake performs the client side of the TLS handshake using the
// ClientHello already stored in c.HandshakeState.Hello (prepared by
// BuildHandshakeState) instead of building a new one. After reading the
// ServerHello it continues with the TLS 1.3 or the TLS 1.2 handshake state
// depending on the negotiated c.vers, and copies the resulting state back
// into c.HandshakeState.
//
// A session already set in c.HandshakeState.Session by the caller is kept
// instead of the one loaded from the session cache.
func (c *UConn) clientHandshake() (err error) {
	// [uTLS section begins]
	hello := c.HandshakeState.Hello.getPrivatePtr()
	// On return, publish the private hello back into c.HandshakeState.Hello.
	defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }()

	// Remember whether the caller pre-set a session (e.g. via SetSessionState)
	// so the cached session loaded below does not overwrite it.
	sessionIsAlreadySet := c.HandshakeState.Session != nil

	// The TLS 1.2 vs TLS 1.3 branch is chosen further down, after the
	// ServerHello has been read, based on c.vers.
	// [uTLS section ends]

	if c.config == nil {
		c.config = defaultConfig()
	}

	// This may be a renegotiation handshake, in which case some fields
	// need to be reset.
	c.didResume = false

	// [uTLS section begins]
	// don't make new ClientHello, use hs.hello
	// preserve the checks from beginning and end of makeClientHello()
	if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
		return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
	}

	// Each ALPN protocol must be 1..255 bytes and the encoded list
	// (1-byte length prefix per entry) must fit in 16 bits.
	nextProtosLength := 0
	for _, proto := range c.config.NextProtos {
		if l := len(proto); l == 0 || l > 255 {
			return errors.New("tls: invalid NextProtos value")
		} else {
			nextProtosLength += 1 + l
		}
	}

	if nextProtosLength > 0xffff {
		return errors.New("tls: NextProtos values too large")
	}

	if c.handshakes > 0 {
		hello.secureRenegotiation = c.clientFinished[:]
	}
	// [uTLS section ends]

	cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
	if cacheKey != "" && session != nil {
		defer func() {
			// If we got a handshake failure when resuming a session, throw away
			// the session ticket. See RFC 5077, Section 3.2.
			//
			// RFC 8446 makes no mention of dropping tickets on failure, but it
			// does require servers to abort on invalid binders, so we need to
			// delete tickets to recover from a corrupted PSK.
			if err != nil {
				c.config.ClientSessionCache.Put(cacheKey, nil)
			}
		}()
	}

	if !sessionIsAlreadySet { // uTLS: do not overwrite already set session
		err = c.SetSessionState(session)
		if err != nil {
			return
		}
	}

	// NOTE(review): hello.marshal() presumably returns the Raw bytes produced
	// by MarshalClientHello when they are set — confirm in the marshal code.
	if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
		return err
	}

	msg, err := c.readHandshake()
	if err != nil {
		return err
	}

	serverHello, ok := msg.(*serverHelloMsg)
	if !ok {
		c.sendAlert(alertUnexpectedMessage)
		return unexpectedMessageError(serverHello, msg)
	}

	if err := c.pickTLSVersion(serverHello); err != nil {
		return err
	}

	// uTLS: do not create new handshakeState, use existing one
	if c.vers == VersionTLS13 {
		hs13 := c.HandshakeState.toPrivate13()
		hs13.serverHello = serverHello
		hs13.hello = hello
		// Keys derived from the cached session are only used when the
		// session came from the cache, not from the caller.
		if !sessionIsAlreadySet {
			hs13.earlySecret = earlySecret
			hs13.binderKey = binderKey
		}
		// In TLS 1.3, session tickets are delivered after the handshake.
		err = hs13.handshake()
		if handshakeState := hs13.toPublic13(); handshakeState != nil {
			c.HandshakeState = *handshakeState
		}
		return err
	}

	hs12 := c.HandshakeState.toPrivate12()
	hs12.serverHello = serverHello
	hs12.hello = hello
	err = hs12.handshake()
	// Copy the TLS 1.2 state back even on failure, so callers can inspect it.
	if handshakeState := hs12.toPublic12(); handshakeState != nil {
		c.HandshakeState = *handshakeState
	}
	if err != nil {
		return err
	}

	// If we had a successful handshake and hs.session is different from
	// the one already cached - cache a new one.
	if cacheKey != "" && hs12.session != nil && session != hs12.session {
		c.config.ClientSessionCache.Put(cacheKey, hs12.session)
	}
	return nil
}
413
414func (uconn *UConn) ApplyConfig() error {
415	for _, ext := range uconn.Extensions {
416		err := ext.writeToUConn(uconn)
417		if err != nil {
418			return err
419		}
420	}
421	return nil
422}
423
424func (uconn *UConn) MarshalClientHello() error {
425	hello := uconn.HandshakeState.Hello
426	headerLength := 2 + 32 + 1 + len(hello.SessionId) +
427		2 + len(hello.CipherSuites)*2 +
428		1 + len(hello.CompressionMethods)
429
430	extensionsLen := 0
431	var paddingExt *UtlsPaddingExtension
432	for _, ext := range uconn.Extensions {
433		if pe, ok := ext.(*UtlsPaddingExtension); !ok {
434			// If not padding - just add length of extension to total length
435			extensionsLen += ext.Len()
436		} else {
437			// If padding - process it later
438			if paddingExt == nil {
439				paddingExt = pe
440			} else {
441				return errors.New("Multiple padding extensions!")
442			}
443		}
444	}
445
446	if paddingExt != nil {
447		// determine padding extension presence and length
448		paddingExt.Update(headerLength + 4 + extensionsLen + 2)
449		extensionsLen += paddingExt.Len()
450	}
451
452	helloLen := headerLength
453	if len(uconn.Extensions) > 0 {
454		helloLen += 2 + extensionsLen // 2 bytes for extensions' length
455	}
456
457	helloBuffer := bytes.Buffer{}
458	bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length
459	// We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens
460	// Write() will become noop, and error will be accessible via Flush(), which is called once in the end
461
462	binary.Write(bufferedWriter, binary.BigEndian, typeClientHello)
463	helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
464	binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes)
465	binary.Write(bufferedWriter, binary.BigEndian, hello.Vers)
466
467	binary.Write(bufferedWriter, binary.BigEndian, hello.Random)
468
469	binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId)))
470	binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId)
471
472	binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
473	for _, suite := range hello.CipherSuites {
474		binary.Write(bufferedWriter, binary.BigEndian, suite)
475	}
476
477	binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods)))
478	binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods)
479
480	if len(uconn.Extensions) > 0 {
481		binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
482		for _, ext := range uconn.Extensions {
483			bufferedWriter.ReadFrom(ext)
484		}
485	}
486
487	err := bufferedWriter.Flush()
488	if err != nil {
489		return err
490	}
491
492	if helloBuffer.Len() != 4+helloLen {
493		return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) +
494			". Got: " + strconv.Itoa(helloBuffer.Len()))
495	}
496
497	hello.Raw = helloBuffer.Bytes()
498	return nil
499}
500
501// get current state of cipher and encrypt zeros to get keystream
502func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) {
503	zeros := make([]byte, length)
504
505	if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok {
506		// AEAD.Seal() does not mutate internal state, other ciphers might
507		return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil
508	}
509	return nil, errors.New("Could not convert OutCipher to cipher.AEAD")
510}
511
512// SetTLSVers sets min and max TLS version in all appropriate places.
513// Function will use first non-zero version parsed in following order:
514//   1) Provided minTLSVers, maxTLSVers
515//   2) specExtensions may have SupportedVersionsExtension
516//   3) [default] min = TLS 1.0, max = TLS 1.2
517//
518// Error is only returned if things are in clearly undesirable state
519// to help user fix them.
520func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) error {
521	if minTLSVers == 0 && maxTLSVers == 0 {
522		// if version is not set explicitly in the ClientHelloSpec, check the SupportedVersions extension
523		supportedVersionsExtensionsPresent := 0
524		for _, e := range specExtensions {
525			switch ext := e.(type) {
526			case *SupportedVersionsExtension:
527				findVersionsInSupportedVersionsExtensions := func(versions []uint16) (uint16, uint16) {
528					// returns (minVers, maxVers)
529					minVers := uint16(0)
530					maxVers := uint16(0)
531					for _, vers := range versions {
532						if vers == GREASE_PLACEHOLDER {
533							continue
534						}
535						if maxVers < vers || maxVers == 0 {
536							maxVers = vers
537						}
538						if minVers > vers || minVers == 0 {
539							minVers = vers
540						}
541					}
542					return minVers, maxVers
543				}
544
545				supportedVersionsExtensionsPresent += 1
546				minTLSVers, maxTLSVers = findVersionsInSupportedVersionsExtensions(ext.Versions)
547				if minTLSVers == 0 && maxTLSVers == 0 {
548					return fmt.Errorf("SupportedVersions extension has invalid Versions field")
549				} // else: proceed
550			}
551		}
552		switch supportedVersionsExtensionsPresent {
553		case 0:
554			// if mandatory for TLS 1.3 extension is not present, just default to 1.2
555			minTLSVers = VersionTLS10
556			maxTLSVers = VersionTLS12
557		case 1:
558		default:
559			return fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions",
560				supportedVersionsExtensionsPresent)
561		}
562	}
563
564	if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS12 {
565		return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers)
566	}
567
568	if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 {
569		return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers)
570	}
571
572	uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers)
573	uconn.config.MinVersion = minTLSVers
574	uconn.config.MaxVersion = maxTLSVers
575
576	return nil
577}
578
579func (uconn *UConn) SetUnderlyingConn(c net.Conn) {
580	uconn.Conn.conn = c
581}
582
583func (uconn *UConn) GetUnderlyingConn() net.Conn {
584	return uconn.Conn.conn
585}
586
587// MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections.
588// Major Hack Alert.
589func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn {
590	tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient}
591	cs := cipherSuiteByID(cipherSuite)
592	if cs == nil {
593		return nil
594	}
595
596	// This is mostly borrowed from establishKeys()
597	clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
598		keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom,
599			cs.macLen, cs.keyLen, cs.ivLen)
600
601	var clientCipher, serverCipher interface{}
602	var clientHash, serverHash macFunction
603	if cs.cipher != nil {
604		clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */)
605		clientHash = cs.mac(version, clientMAC)
606		serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */)
607		serverHash = cs.mac(version, serverMAC)
608	} else {
609		clientCipher = cs.aead(clientKey, clientIV)
610		serverCipher = cs.aead(serverKey, serverIV)
611	}
612
613	if isClient {
614		tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash)
615		tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash)
616	} else {
617		tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash)
618		tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash)
619	}
620
621	// skip the handshake states
622	tlsConn.handshakeStatus = 1
623	tlsConn.cipherSuite = cipherSuite
624	tlsConn.haveVers = true
625	tlsConn.vers = version
626
627	// Update to the new cipher specs
628	// and consume the finished messages
629	tlsConn.in.changeCipherSpec()
630	tlsConn.out.changeCipherSpec()
631
632	tlsConn.in.incSeq()
633	tlsConn.out.incSeq()
634
635	return tlsConn
636}
637
// makeSupportedVersions returns every version from maxVers down to minVers,
// inclusive, in descending order (e.g. 0x0304, 0x0303, 0x0302, 0x0301).
// It returns nil when minVers > maxVers. The length is computed in int so
// that neither an inverted range nor the full 0..0xffff range wraps around
// in uint16 arithmetic.
func makeSupportedVersions(minVers, maxVers uint16) []uint16 {
	if minVers > maxVers {
		return nil
	}
	a := make([]uint16, int(maxVers)-int(minVers)+1)
	for i := range a {
		a[i] = maxVers - uint16(i)
	}
	return a
}
645