1// Copyright 2017 Google Inc. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bufio"
9	"bytes"
10	"crypto/cipher"
11	"encoding/binary"
12	"errors"
13	"fmt"
14	"io"
15	"net"
16	"strconv"
17	"sync/atomic"
18)
19
// UConn is a uTLS client connection: the standard *Conn plus the state used
// to build and send a customized ClientHello.
type UConn struct {
	*Conn

	// Extensions are the ClientHello extensions. ApplyConfig applies them to
	// the handshake state and MarshalClientHello writes them in slice order.
	Extensions    []TLSExtension
	// ClientHelloID selects the ClientHello to build: HelloGolang uses the
	// default crypto/tls ClientHello, any other ID applies a preset.
	ClientHelloID ClientHelloID

	// ClientHelloBuilt reports whether BuildHandshakeState has already done
	// its one-time ClientHello construction.
	ClientHelloBuilt bool
	// HandshakeState is the client handshake state, including the ClientHello
	// (HandshakeState.Hello) that is sent during the handshake.
	HandshakeState   ClientHandshakeState

	// GetSessionID, if non-nil, derives the ClientHello session ID from the
	// session ticket (the ID may or may not actually depend on the ticket).
	// If nil, SetSessionState uses a random session ID.
	GetSessionID func(ticket []byte) [32]byte

	// greaseSeed holds GREASE values. NOTE(review): not used in this file;
	// presumably populated by the preset/GREASE code — confirm.
	greaseSeed [ssl_grease_last_index]uint16
}
34
35// UClient returns a new uTLS client, with behavior depending on clientHelloID.
36// Config CAN be nil, but make sure to eventually specify ServerName.
37func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn {
38	if config == nil {
39		config = &Config{}
40	}
41	tlsConn := Conn{conn: conn, config: config, isClient: true}
42	handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}}
43	uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState}
44	uconn.HandshakeState.uconn = &uconn
45	return &uconn
46}
47
48// BuildHandshakeState behavior varies based on ClientHelloID and
49// whether it was already called before.
50// If HelloGolang:
51//   [only once] make default ClientHello and overwrite existing state
52// If any other mimicking ClientHelloID is used:
53//   [only once] make ClientHello based on ID and overwrite existing state
54//   [each call] apply uconn.Extensions config to internal crypto/tls structures
55//   [each call] marshal ClientHello.
56//
57// BuildHandshakeState is automatically called before uTLS performs handshake,
58// amd should only be called explicitly to inspect/change fields of
59// default/mimicked ClientHello.
60func (uconn *UConn) BuildHandshakeState() error {
61	if uconn.ClientHelloID == HelloGolang {
62		if uconn.ClientHelloBuilt {
63			return nil
64		}
65
66		// use default Golang ClientHello.
67		hello, ecdheParams, err := uconn.makeClientHello()
68		if err != nil {
69			return err
70		}
71
72		uconn.HandshakeState.Hello = hello.getPublicPtr()
73		uconn.HandshakeState.State13.EcdheParams = ecdheParams
74		uconn.HandshakeState.C = uconn.Conn
75	} else {
76		if !uconn.ClientHelloBuilt {
77			err := uconn.applyPresetByID(uconn.ClientHelloID)
78			if err != nil {
79				return err
80			}
81		}
82
83		err := uconn.ApplyConfig()
84		if err != nil {
85			return err
86		}
87		err = uconn.MarshalClientHello()
88		if err != nil {
89			return err
90		}
91	}
92	uconn.ClientHelloBuilt = true
93	return nil
94}
95
96// SetSessionState sets the session ticket, which may be preshared or fake.
97// If session is nil, the body of session ticket extension will be unset,
98// but the extension itself still MAY be present for mimicking purposes.
99// Session tickets to be reused - use same cache on following connections.
100func (uconn *UConn) SetSessionState(session *ClientSessionState) error {
101	uconn.HandshakeState.Session = session
102	var sessionTicket []uint8
103	if session != nil {
104		sessionTicket = session.sessionTicket
105	}
106	uconn.HandshakeState.Hello.TicketSupported = true
107	uconn.HandshakeState.Hello.SessionTicket = sessionTicket
108
109	for _, ext := range uconn.Extensions {
110		st, ok := ext.(*SessionTicketExtension)
111		if !ok {
112			continue
113		}
114		st.Session = session
115		if session != nil {
116			if len(session.SessionTicket()) > 0 {
117				if uconn.GetSessionID != nil {
118					sid := uconn.GetSessionID(session.SessionTicket())
119					uconn.HandshakeState.Hello.SessionId = sid[:]
120					return nil
121				}
122			}
123			var sessionID [32]byte
124			_, err := io.ReadFull(uconn.config.rand(), sessionID[:])
125			if err != nil {
126				return err
127			}
128			uconn.HandshakeState.Hello.SessionId = sessionID[:]
129		}
130		return nil
131	}
132	return nil
133}
134
135// If you want session tickets to be reused - use same cache on following connections
136func (uconn *UConn) SetSessionCache(cache ClientSessionCache) {
137	uconn.config.ClientSessionCache = cache
138	uconn.HandshakeState.Hello.TicketSupported = true
139}
140
141// SetClientRandom sets client random explicitly.
142// BuildHandshakeFirst() must be called before SetClientRandom.
143// r must to be 32 bytes long.
144func (uconn *UConn) SetClientRandom(r []byte) error {
145	if len(r) != 32 {
146		return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r)))
147	} else {
148		uconn.HandshakeState.Hello.Random = make([]byte, 32)
149		copy(uconn.HandshakeState.Hello.Random, r)
150		return nil
151	}
152}
153
154func (uconn *UConn) SetSNI(sni string) {
155	hname := hostnameInSNI(sni)
156	uconn.config.ServerName = hname
157	for _, ext := range uconn.Extensions {
158		sniExt, ok := ext.(*SNIExtension)
159		if ok {
160			sniExt.ServerName = hname
161		}
162	}
163}
164
// Handshake runs the TLS handshake if it has not been run yet. For a client
// it first calls BuildHandshakeState, so the ClientHello in c.HandshakeState
// is built (or rebuilt) right before use, and then runs the uTLS
// clientHandshake; otherwise it runs the server handshake.
//
// Once clientHandshake/serverHandshake has failed, the same error is returned
// on every later call; once the handshake has completed, later calls return nil.
func (c *UConn) Handshake() error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.handshakeComplete() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	if c.isClient {
		// [uTLS section begins]
		// NOTE(review): this error is returned without being stored in
		// c.handshakeErr, so a later call will try to build the state again.
		err := c.BuildHandshakeState()
		if err != nil {
			return err
		}
		// [uTLS section ends]

		c.handshakeErr = c.clientHandshake()
	} else {
		c.handshakeErr = c.serverHandshake()
	}
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// If an error occurred during the handshake try to flush the
		// alert that might be left in the buffer.
		c.flush()
	}

	if c.handshakeErr == nil && !c.handshakeComplete() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}

	return c.handshakeErr
}
207
// Write writes data to the connection, running the handshake first if it has
// not completed yet.
//
// This is copy-pasted from tls.Conn.Write in its entirety; it exists so that
// c.Handshake() below resolves to the uTLS Handshake rather than the
// crypto/tls one.
func (c *UConn) Write(b []byte) (int, error) {
	// Interlock with Close: the low bit of activeCall is the closed flag
	// (errClosed is returned when it is set), and each in-flight Write adds 2
	// for as long as it runs.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			defer atomic.AddInt32(&c.activeCall, -2)
			break
		}
	}

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext
	// attack when using block mode ciphers due to predictable IVs.
	// This can be prevented by splitting each Application Data
	// record into two records, effectively randomizing the IV.
	//
	// https://www.openssl.org/~bodo/tls-cbc.txt
	// https://bugzilla.mozilla.org/show_bug.cgi?id=665814
	// https://www.imperialviolet.org/2012/01/15/beastfollowup.html

	// m counts the bytes already written in the 1-byte split record.
	var m int
	if len(b) > 1 && c.vers <= VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
265
// clientHandshake performs the client side of the TLS handshake using the
// ClientHello already prepared in c.HandshakeState; it does not build a new
// one. Depending on the version picked from the ServerHello, it continues with
// the TLS 1.3 or TLS 1.2 state machine and copies the resulting handshake state
// back into c.HandshakeState.
//
// err is a named result so the deferred closure below can drop a cached
// session ticket when a resumption attempt fails.
func (c *UConn) clientHandshake() (err error) {
	// [uTLS section begins]
	// Work on the private form of the prepared ClientHello; the deferred call
	// stores it back into c.HandshakeState.Hello when this function returns.
	hello := c.HandshakeState.Hello.getPrivatePtr()
	defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }()

	// A session set by the caller beforehand (e.g. via SetSessionState) must
	// not be replaced by the one loaded from the session cache below.
	sessionIsAlreadySet := c.HandshakeState.Session != nil
	// [uTLS section ends]

	if c.config == nil {
		c.config = defaultConfig()
	}

	// This may be a renegotiation handshake, in which case some fields
	// need to be reset.
	c.didResume = false

	// [uTLS section begins]
	// Don't make a new ClientHello; use the prepared one (hello), but
	// preserve the checks from the beginning and end of makeClientHello().
	if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify {
		return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
	}

	nextProtosLength := 0
	for _, proto := range c.config.NextProtos {
		if l := len(proto); l == 0 || l > 255 {
			return errors.New("tls: invalid NextProtos value")
		} else {
			nextProtosLength += 1 + l
		}
	}

	if nextProtosLength > 0xffff {
		return errors.New("tls: NextProtos values too large")
	}

	if c.handshakes > 0 {
		hello.secureRenegotiation = c.clientFinished[:]
	}
	// [uTLS section ends]

	cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
	if cacheKey != "" && session != nil {
		defer func() {
			// If we got a handshake failure when resuming a session, throw away
			// the session ticket. See RFC 5077, Section 3.2.
			//
			// RFC 8446 makes no mention of dropping tickets on failure, but it
			// does require servers to abort on invalid binders, so we need to
			// delete tickets to recover from a corrupted PSK.
			if err != nil {
				c.config.ClientSessionCache.Put(cacheKey, nil)
			}
		}()
	}

	// NOTE(review): SetSessionState updates c.HandshakeState.Hello, not the
	// private `hello` taken above. Whether those changes reach the wire
	// depends on getPrivatePtr sharing state with the public struct — verify.
	if !sessionIsAlreadySet { // uTLS: do not overwrite already set session
		err = c.SetSessionState(session)
		if err != nil {
			return
		}
	}

	// NOTE(review): presumably marshal() reuses the raw bytes produced by
	// MarshalClientHello when they are present — confirm.
	if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
		return err
	}

	msg, err := c.readHandshake()
	if err != nil {
		return err
	}

	serverHello, ok := msg.(*serverHelloMsg)
	if !ok {
		c.sendAlert(alertUnexpectedMessage)
		return unexpectedMessageError(serverHello, msg)
	}

	if err := c.pickTLSVersion(serverHello); err != nil {
		return err
	}

	// uTLS: do not create new handshakeState, use existing one
	if c.vers == VersionTLS13 {
		hs13 := c.HandshakeState.toPrivate13()
		hs13.serverHello = serverHello
		hs13.hello = hello
		// Keys derived from the cached session apply only when the session
		// came from the cache, not when the caller set one explicitly.
		if !sessionIsAlreadySet {
			hs13.earlySecret = earlySecret
			hs13.binderKey = binderKey
		}
		// In TLS 1.3, session tickets are delivered after the handshake.
		err = hs13.handshake()
		// Publish the resulting state even if the handshake failed.
		if handshakeState := hs13.toPublic13(); handshakeState != nil {
			c.HandshakeState = *handshakeState
		}
		return err
	}

	hs12 := c.HandshakeState.toPrivate12()
	hs12.serverHello = serverHello
	hs12.hello = hello
	err = hs12.handshake()
	// Publish the resulting state even if the handshake failed.
	if handshakeState := hs12.toPublic12(); handshakeState != nil {
		c.HandshakeState = *handshakeState
	}
	if err != nil {
		return err
	}

	// If we had a successful handshake and hs.session is different from
	// the one already cached - cache a new one.
	if cacheKey != "" && hs12.session != nil && session != hs12.session {
		c.config.ClientSessionCache.Put(cacheKey, hs12.session)
	}
	return nil
}
388
389func (uconn *UConn) ApplyConfig() error {
390	for _, ext := range uconn.Extensions {
391		err := ext.writeToUConn(uconn)
392		if err != nil {
393			return err
394		}
395	}
396	return nil
397}
398
399func (uconn *UConn) MarshalClientHello() error {
400	hello := uconn.HandshakeState.Hello
401	headerLength := 2 + 32 + 1 + len(hello.SessionId) +
402		2 + len(hello.CipherSuites)*2 +
403		1 + len(hello.CompressionMethods)
404
405	extensionsLen := 0
406	var paddingExt *UtlsPaddingExtension
407	for _, ext := range uconn.Extensions {
408		if pe, ok := ext.(*UtlsPaddingExtension); !ok {
409			// If not padding - just add length of extension to total length
410			extensionsLen += ext.Len()
411		} else {
412			// If padding - process it later
413			if paddingExt == nil {
414				paddingExt = pe
415			} else {
416				return errors.New("Multiple padding extensions!")
417			}
418		}
419	}
420
421	if paddingExt != nil {
422		// determine padding extension presence and length
423		paddingExt.Update(headerLength + 4 + extensionsLen + 2)
424		extensionsLen += paddingExt.Len()
425	}
426
427	helloLen := headerLength
428	if len(uconn.Extensions) > 0 {
429		helloLen += 2 + extensionsLen // 2 bytes for extensions' length
430	}
431
432	helloBuffer := bytes.Buffer{}
433	bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length
434	// We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens
435	// Write() will become noop, and error will be accessible via Flush(), which is called once in the end
436
437	binary.Write(bufferedWriter, binary.BigEndian, typeClientHello)
438	helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24
439	binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes)
440	binary.Write(bufferedWriter, binary.BigEndian, hello.Vers)
441
442	binary.Write(bufferedWriter, binary.BigEndian, hello.Random)
443
444	binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId)))
445	binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId)
446
447	binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1))
448	for _, suite := range hello.CipherSuites {
449		binary.Write(bufferedWriter, binary.BigEndian, suite)
450	}
451
452	binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods)))
453	binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods)
454
455	if len(uconn.Extensions) > 0 {
456		binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen))
457		for _, ext := range uconn.Extensions {
458			bufferedWriter.ReadFrom(ext)
459		}
460	}
461
462	err := bufferedWriter.Flush()
463	if err != nil {
464		return err
465	}
466
467	if helloBuffer.Len() != 4+helloLen {
468		return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) +
469			". Got: " + strconv.Itoa(helloBuffer.Len()))
470	}
471
472	hello.Raw = helloBuffer.Bytes()
473	return nil
474}
475
476// get current state of cipher and encrypt zeros to get keystream
477func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) {
478	zeros := make([]byte, length)
479
480	if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok {
481		// AEAD.Seal() does not mutate internal state, other ciphers might
482		return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil
483	}
484	return nil, errors.New("Could not convert OutCipher to cipher.AEAD")
485}
486
487// SetTLSVers sets min and max TLS version in all appropriate places.
488// Function will use first non-zero version parsed in following order:
489//   1) Provided minTLSVers, maxTLSVers
490//   2) specExtensions may have SupportedVersionsExtension
491//   3) [default] min = TLS 1.0, max = TLS 1.2
492//
493// Error is only returned if things are in clearly undesirable state
494// to help user fix them.
495func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) error {
496	if minTLSVers == 0 && maxTLSVers == 0 {
497		// if version is not set explicitly in the ClientHelloSpec, check the SupportedVersions extension
498		supportedVersionsExtensionsPresent := 0
499		for _, e := range specExtensions {
500			switch ext := e.(type) {
501			case *SupportedVersionsExtension:
502				findVersionsInSupportedVersionsExtensions := func(versions []uint16) (uint16, uint16) {
503					// returns (minVers, maxVers)
504					minVers := uint16(0)
505					maxVers := uint16(0)
506					for _, vers := range versions {
507						if vers == GREASE_PLACEHOLDER {
508							continue
509						}
510						if maxVers < vers || maxVers == 0 {
511							maxVers = vers
512						}
513						if minVers > vers || minVers == 0 {
514							minVers = vers
515						}
516					}
517					return minVers, maxVers
518				}
519
520				supportedVersionsExtensionsPresent += 1
521				minTLSVers, maxTLSVers = findVersionsInSupportedVersionsExtensions(ext.Versions)
522				if minTLSVers == 0 && maxTLSVers == 0 {
523					return fmt.Errorf("SupportedVersions extension has invalid Versions field")
524				} // else: proceed
525			}
526		}
527		switch supportedVersionsExtensionsPresent {
528		case 0:
529			// if mandatory for TLS 1.3 extension is not present, just default to 1.2
530			minTLSVers = VersionTLS10
531			maxTLSVers = VersionTLS12
532		case 1:
533		default:
534			return fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions",
535				supportedVersionsExtensionsPresent)
536		}
537	}
538
539	if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS12 {
540		return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers)
541	}
542
543	if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 {
544		return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers)
545	}
546
547	uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers)
548	uconn.config.MinVersion = minTLSVers
549	uconn.config.MaxVersion = maxTLSVers
550
551	return nil
552}
553
554func (uconn *UConn) SetUnderlyingConn(c net.Conn) {
555	uconn.Conn.conn = c
556}
557
558func (uconn *UConn) GetUnderlyingConn() net.Conn {
559	return uconn.Conn.conn
560}
561
562// MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections.
563// Major Hack Alert.
564func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn {
565	tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient}
566	cs := cipherSuiteByID(cipherSuite)
567	if cs == nil {
568		return nil
569	}
570
571	// This is mostly borrowed from establishKeys()
572	clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
573		keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom,
574			cs.macLen, cs.keyLen, cs.ivLen)
575
576	var clientCipher, serverCipher interface{}
577	var clientHash, serverHash macFunction
578	if cs.cipher != nil {
579		clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */)
580		clientHash = cs.mac(version, clientMAC)
581		serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */)
582		serverHash = cs.mac(version, serverMAC)
583	} else {
584		clientCipher = cs.aead(clientKey, clientIV)
585		serverCipher = cs.aead(serverKey, serverIV)
586	}
587
588	if isClient {
589		tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash)
590		tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash)
591	} else {
592		tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash)
593		tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash)
594	}
595
596	// skip the handshake states
597	tlsConn.handshakeStatus = 1
598	tlsConn.cipherSuite = cipherSuite
599	tlsConn.haveVers = true
600	tlsConn.vers = version
601
602	// Update to the new cipher specs
603	// and consume the finished messages
604	tlsConn.in.changeCipherSpec()
605	tlsConn.out.changeCipherSpec()
606
607	tlsConn.in.incSeq()
608	tlsConn.out.incSeq()
609
610	return tlsConn
611}
612
// makeSupportedVersions returns every protocol version from maxVers down to
// minVers inclusive, in descending order; SetTLSVers stores this list as the
// ClientHello's SupportedVersions. It returns nil if minVers > maxVers
// instead of wrapping around.
func makeSupportedVersions(minVers, maxVers uint16) []uint16 {
	if minVers > maxVers {
		return nil
	}
	// Count in int so that even the full 0..0xffff range cannot overflow.
	versions := make([]uint16, 0, int(maxVers)-int(minVers)+1)
	for v := int(maxVers); v >= int(minVers); v-- {
		versions = append(versions, uint16(v))
	}
	return versions
}
620