1// Copyright 2017 Google Inc. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package tls 6 7import ( 8 "bufio" 9 "bytes" 10 "crypto/cipher" 11 "encoding/binary" 12 "errors" 13 "fmt" 14 "io" 15 "net" 16 "strconv" 17 "sync/atomic" 18) 19 20type UConn struct { 21 *Conn 22 23 Extensions []TLSExtension 24 ClientHelloID ClientHelloID 25 26 ClientHelloBuilt bool 27 HandshakeState ClientHandshakeState 28 29 // sessionID may or may not depend on ticket; nil => random 30 GetSessionID func(ticket []byte) [32]byte 31 32 greaseSeed [ssl_grease_last_index]uint16 33} 34 35// UClient returns a new uTLS client, with behavior depending on clientHelloID. 36// Config CAN be nil, but make sure to eventually specify ServerName. 37func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn { 38 if config == nil { 39 config = &Config{} 40 } 41 tlsConn := Conn{conn: conn, config: config, isClient: true} 42 handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}} 43 uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState} 44 uconn.HandshakeState.uconn = &uconn 45 return &uconn 46} 47 48// BuildHandshakeState behavior varies based on ClientHelloID and 49// whether it was already called before. 50// If HelloGolang: 51// [only once] make default ClientHello and overwrite existing state 52// If any other mimicking ClientHelloID is used: 53// [only once] make ClientHello based on ID and overwrite existing state 54// [each call] apply uconn.Extensions config to internal crypto/tls structures 55// [each call] marshal ClientHello. 56// 57// BuildHandshakeState is automatically called before uTLS performs handshake, 58// amd should only be called explicitly to inspect/change fields of 59// default/mimicked ClientHello. 60func (uconn *UConn) BuildHandshakeState() error { 61 if uconn.ClientHelloID == HelloGolang { 62 if uconn.ClientHelloBuilt { 63 return nil 64 } 65 66 // use default Golang ClientHello. 67 hello, ecdheParams, err := uconn.makeClientHello() 68 if err != nil { 69 return err 70 } 71 72 uconn.HandshakeState.Hello = hello.getPublicPtr() 73 uconn.HandshakeState.State13.EcdheParams = ecdheParams 74 uconn.HandshakeState.C = uconn.Conn 75 } else { 76 if !uconn.ClientHelloBuilt { 77 err := uconn.applyPresetByID(uconn.ClientHelloID) 78 if err != nil { 79 return err 80 } 81 } 82 83 err := uconn.ApplyConfig() 84 if err != nil { 85 return err 86 } 87 err = uconn.MarshalClientHello() 88 if err != nil { 89 return err 90 } 91 } 92 uconn.ClientHelloBuilt = true 93 return nil 94} 95 96// SetSessionState sets the session ticket, which may be preshared or fake. 97// If session is nil, the body of session ticket extension will be unset, 98// but the extension itself still MAY be present for mimicking purposes. 99// Session tickets to be reused - use same cache on following connections. 100func (uconn *UConn) SetSessionState(session *ClientSessionState) error { 101 uconn.HandshakeState.Session = session 102 var sessionTicket []uint8 103 if session != nil { 104 sessionTicket = session.sessionTicket 105 } 106 uconn.HandshakeState.Hello.TicketSupported = true 107 uconn.HandshakeState.Hello.SessionTicket = sessionTicket 108 109 for _, ext := range uconn.Extensions { 110 st, ok := ext.(*SessionTicketExtension) 111 if !ok { 112 continue 113 } 114 st.Session = session 115 if session != nil { 116 if len(session.SessionTicket()) > 0 { 117 if uconn.GetSessionID != nil { 118 sid := uconn.GetSessionID(session.SessionTicket()) 119 uconn.HandshakeState.Hello.SessionId = sid[:] 120 return nil 121 } 122 } 123 var sessionID [32]byte 124 _, err := io.ReadFull(uconn.config.rand(), sessionID[:]) 125 if err != nil { 126 return err 127 } 128 uconn.HandshakeState.Hello.SessionId = sessionID[:] 129 } 130 return nil 131 } 132 return nil 133} 134 135// If you want session tickets to be reused - use same cache on following connections 136func (uconn *UConn) SetSessionCache(cache ClientSessionCache) { 137 uconn.config.ClientSessionCache = cache 138 uconn.HandshakeState.Hello.TicketSupported = true 139} 140 141// SetClientRandom sets client random explicitly. 142// BuildHandshakeFirst() must be called before SetClientRandom. 143// r must to be 32 bytes long. 144func (uconn *UConn) SetClientRandom(r []byte) error { 145 if len(r) != 32 { 146 return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r))) 147 } else { 148 uconn.HandshakeState.Hello.Random = make([]byte, 32) 149 copy(uconn.HandshakeState.Hello.Random, r) 150 return nil 151 } 152} 153 154func (uconn *UConn) SetSNI(sni string) { 155 hname := hostnameInSNI(sni) 156 uconn.config.ServerName = hname 157 for _, ext := range uconn.Extensions { 158 sniExt, ok := ext.(*SNIExtension) 159 if ok { 160 sniExt.ServerName = hname 161 } 162 } 163} 164 165// Handshake runs the client handshake using given clientHandshakeState 166// Requires hs.hello, and, optionally, hs.session to be set. 167func (c *UConn) Handshake() error { 168 c.handshakeMutex.Lock() 169 defer c.handshakeMutex.Unlock() 170 171 if err := c.handshakeErr; err != nil { 172 return err 173 } 174 if c.handshakeComplete() { 175 return nil 176 } 177 178 c.in.Lock() 179 defer c.in.Unlock() 180 181 if c.isClient { 182 // [uTLS section begins] 183 err := c.BuildHandshakeState() 184 if err != nil { 185 return err 186 } 187 // [uTLS section ends] 188 189 c.handshakeErr = c.clientHandshake() 190 } else { 191 c.handshakeErr = c.serverHandshake() 192 } 193 if c.handshakeErr == nil { 194 c.handshakes++ 195 } else { 196 // If an error occurred during the hadshake try to flush the 197 // alert that might be left in the buffer. 198 c.flush() 199 } 200 201 if c.handshakeErr == nil && !c.handshakeComplete() { 202 c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") 203 } 204 205 return c.handshakeErr 206} 207 208// Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one, not tls. 209// Write writes data to the connection. 210func (c *UConn) Write(b []byte) (int, error) { 211 // interlock with Close below 212 for { 213 x := atomic.LoadInt32(&c.activeCall) 214 if x&1 != 0 { 215 return 0, errClosed 216 } 217 if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { 218 defer atomic.AddInt32(&c.activeCall, -2) 219 break 220 } 221 } 222 223 if err := c.Handshake(); err != nil { 224 return 0, err 225 } 226 227 c.out.Lock() 228 defer c.out.Unlock() 229 230 if err := c.out.err; err != nil { 231 return 0, err 232 } 233 234 if !c.handshakeComplete() { 235 return 0, alertInternalError 236 } 237 238 if c.closeNotifySent { 239 return 0, errShutdown 240 } 241 242 // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext 243 // attack when using block mode ciphers due to predictable IVs. 244 // This can be prevented by splitting each Application Data 245 // record into two records, effectively randomizing the IV. 246 // 247 // https://www.openssl.org/~bodo/tls-cbc.txt 248 // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 249 // https://www.imperialviolet.org/2012/01/15/beastfollowup.html 250 251 var m int 252 if len(b) > 1 && c.vers <= VersionTLS10 { 253 if _, ok := c.out.cipher.(cipher.BlockMode); ok { 254 n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) 255 if err != nil { 256 return n, c.out.setErrorLocked(err) 257 } 258 m, b = 1, b[1:] 259 } 260 } 261 262 n, err := c.writeRecordLocked(recordTypeApplicationData, b) 263 return n + m, c.out.setErrorLocked(err) 264} 265 266// clientHandshakeWithOneState checks that exactly one expected state is set (1.2 or 1.3) 267// and performs client TLS handshake with that state 268func (c *UConn) clientHandshake() (err error) { 269 // [uTLS section begins] 270 hello := c.HandshakeState.Hello.getPrivatePtr() 271 defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }() 272 273 sessionIsAlreadySet := c.HandshakeState.Session != nil 274 275 // after this point exactly 1 out of 2 HandshakeState pointers is non-nil, 276 // useTLS13 variable tells which pointer 277 // [uTLS section ends] 278 279 if c.config == nil { 280 c.config = defaultConfig() 281 } 282 283 // This may be a renegotiation handshake, in which case some fields 284 // need to be reset. 285 c.didResume = false 286 287 // [uTLS section begins] 288 // don't make new ClientHello, use hs.hello 289 // preserve the checks from beginning and end of makeClientHello() 290 if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify { 291 return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") 292 } 293 294 nextProtosLength := 0 295 for _, proto := range c.config.NextProtos { 296 if l := len(proto); l == 0 || l > 255 { 297 return errors.New("tls: invalid NextProtos value") 298 } else { 299 nextProtosLength += 1 + l 300 } 301 } 302 303 if nextProtosLength > 0xffff { 304 return errors.New("tls: NextProtos values too large") 305 } 306 307 if c.handshakes > 0 { 308 hello.secureRenegotiation = c.clientFinished[:] 309 } 310 // [uTLS section ends] 311 312 cacheKey, session, earlySecret, binderKey := c.loadSession(hello) 313 if cacheKey != "" && session != nil { 314 defer func() { 315 // If we got a handshake failure when resuming a session, throw away 316 // the session ticket. See RFC 5077, Section 3.2. 317 // 318 // RFC 8446 makes no mention of dropping tickets on failure, but it 319 // does require servers to abort on invalid binders, so we need to 320 // delete tickets to recover from a corrupted PSK. 321 if err != nil { 322 c.config.ClientSessionCache.Put(cacheKey, nil) 323 } 324 }() 325 } 326 327 if !sessionIsAlreadySet { // uTLS: do not overwrite already set session 328 err = c.SetSessionState(session) 329 if err != nil { 330 return 331 } 332 } 333 334 if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { 335 return err 336 } 337 338 msg, err := c.readHandshake() 339 if err != nil { 340 return err 341 } 342 343 serverHello, ok := msg.(*serverHelloMsg) 344 if !ok { 345 c.sendAlert(alertUnexpectedMessage) 346 return unexpectedMessageError(serverHello, msg) 347 } 348 349 if err := c.pickTLSVersion(serverHello); err != nil { 350 return err 351 } 352 353 // uTLS: do not create new handshakeState, use existing one 354 if c.vers == VersionTLS13 { 355 hs13 := c.HandshakeState.toPrivate13() 356 hs13.serverHello = serverHello 357 hs13.hello = hello 358 if !sessionIsAlreadySet { 359 hs13.earlySecret = earlySecret 360 hs13.binderKey = binderKey 361 } 362 // In TLS 1.3, session tickets are delivered after the handshake. 363 err = hs13.handshake() 364 if handshakeState := hs13.toPublic13(); handshakeState != nil { 365 c.HandshakeState = *handshakeState 366 } 367 return err 368 } 369 370 hs12 := c.HandshakeState.toPrivate12() 371 hs12.serverHello = serverHello 372 hs12.hello = hello 373 err = hs12.handshake() 374 if handshakeState := hs12.toPublic12(); handshakeState != nil { 375 c.HandshakeState = *handshakeState 376 } 377 if err != nil { 378 return err 379 } 380 381 // If we had a successful handshake and hs.session is different from 382 // the one already cached - cache a new one. 383 if cacheKey != "" && hs12.session != nil && session != hs12.session { 384 c.config.ClientSessionCache.Put(cacheKey, hs12.session) 385 } 386 return nil 387} 388 389func (uconn *UConn) ApplyConfig() error { 390 for _, ext := range uconn.Extensions { 391 err := ext.writeToUConn(uconn) 392 if err != nil { 393 return err 394 } 395 } 396 return nil 397} 398 399func (uconn *UConn) MarshalClientHello() error { 400 hello := uconn.HandshakeState.Hello 401 headerLength := 2 + 32 + 1 + len(hello.SessionId) + 402 2 + len(hello.CipherSuites)*2 + 403 1 + len(hello.CompressionMethods) 404 405 extensionsLen := 0 406 var paddingExt *UtlsPaddingExtension 407 for _, ext := range uconn.Extensions { 408 if pe, ok := ext.(*UtlsPaddingExtension); !ok { 409 // If not padding - just add length of extension to total length 410 extensionsLen += ext.Len() 411 } else { 412 // If padding - process it later 413 if paddingExt == nil { 414 paddingExt = pe 415 } else { 416 return errors.New("Multiple padding extensions!") 417 } 418 } 419 } 420 421 if paddingExt != nil { 422 // determine padding extension presence and length 423 paddingExt.Update(headerLength + 4 + extensionsLen + 2) 424 extensionsLen += paddingExt.Len() 425 } 426 427 helloLen := headerLength 428 if len(uconn.Extensions) > 0 { 429 helloLen += 2 + extensionsLen // 2 bytes for extensions' length 430 } 431 432 helloBuffer := bytes.Buffer{} 433 bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length 434 // We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens 435 // Write() will become noop, and error will be accessible via Flush(), which is called once in the end 436 437 binary.Write(bufferedWriter, binary.BigEndian, typeClientHello) 438 helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24 439 binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes) 440 binary.Write(bufferedWriter, binary.BigEndian, hello.Vers) 441 442 binary.Write(bufferedWriter, binary.BigEndian, hello.Random) 443 444 binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId))) 445 binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId) 446 447 binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1)) 448 for _, suite := range hello.CipherSuites { 449 binary.Write(bufferedWriter, binary.BigEndian, suite) 450 } 451 452 binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods))) 453 binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods) 454 455 if len(uconn.Extensions) > 0 { 456 binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen)) 457 for _, ext := range uconn.Extensions { 458 bufferedWriter.ReadFrom(ext) 459 } 460 } 461 462 err := bufferedWriter.Flush() 463 if err != nil { 464 return err 465 } 466 467 if helloBuffer.Len() != 4+helloLen { 468 return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) + 469 ". Got: " + strconv.Itoa(helloBuffer.Len())) 470 } 471 472 hello.Raw = helloBuffer.Bytes() 473 return nil 474} 475 476// get current state of cipher and encrypt zeros to get keystream 477func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) { 478 zeros := make([]byte, length) 479 480 if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok { 481 // AEAD.Seal() does not mutate internal state, other ciphers might 482 return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil 483 } 484 return nil, errors.New("Could not convert OutCipher to cipher.AEAD") 485} 486 487// SetTLSVers sets min and max TLS version in all appropriate places. 488// Function will use first non-zero version parsed in following order: 489// 1) Provided minTLSVers, maxTLSVers 490// 2) specExtensions may have SupportedVersionsExtension 491// 3) [default] min = TLS 1.0, max = TLS 1.2 492// 493// Error is only returned if things are in clearly undesirable state 494// to help user fix them. 495func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) error { 496 if minTLSVers == 0 && maxTLSVers == 0 { 497 // if version is not set explicitly in the ClientHelloSpec, check the SupportedVersions extension 498 supportedVersionsExtensionsPresent := 0 499 for _, e := range specExtensions { 500 switch ext := e.(type) { 501 case *SupportedVersionsExtension: 502 findVersionsInSupportedVersionsExtensions := func(versions []uint16) (uint16, uint16) { 503 // returns (minVers, maxVers) 504 minVers := uint16(0) 505 maxVers := uint16(0) 506 for _, vers := range versions { 507 if vers == GREASE_PLACEHOLDER { 508 continue 509 } 510 if maxVers < vers || maxVers == 0 { 511 maxVers = vers 512 } 513 if minVers > vers || minVers == 0 { 514 minVers = vers 515 } 516 } 517 return minVers, maxVers 518 } 519 520 supportedVersionsExtensionsPresent += 1 521 minTLSVers, maxTLSVers = findVersionsInSupportedVersionsExtensions(ext.Versions) 522 if minTLSVers == 0 && maxTLSVers == 0 { 523 return fmt.Errorf("SupportedVersions extension has invalid Versions field") 524 } // else: proceed 525 } 526 } 527 switch supportedVersionsExtensionsPresent { 528 case 0: 529 // if mandatory for TLS 1.3 extension is not present, just default to 1.2 530 minTLSVers = VersionTLS10 531 maxTLSVers = VersionTLS12 532 case 1: 533 default: 534 return fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions", 535 supportedVersionsExtensionsPresent) 536 } 537 } 538 539 if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS12 { 540 return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers) 541 } 542 543 if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 { 544 return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers) 545 } 546 547 uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers) 548 uconn.config.MinVersion = minTLSVers 549 uconn.config.MaxVersion = maxTLSVers 550 551 return nil 552} 553 554func (uconn *UConn) SetUnderlyingConn(c net.Conn) { 555 uconn.Conn.conn = c 556} 557 558func (uconn *UConn) GetUnderlyingConn() net.Conn { 559 return uconn.Conn.conn 560} 561 562// MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections. 563// Major Hack Alert. 564func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn { 565 tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient} 566 cs := cipherSuiteByID(cipherSuite) 567 if cs == nil { 568 return nil 569 } 570 571 // This is mostly borrowed from establishKeys() 572 clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := 573 keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom, 574 cs.macLen, cs.keyLen, cs.ivLen) 575 576 var clientCipher, serverCipher interface{} 577 var clientHash, serverHash macFunction 578 if cs.cipher != nil { 579 clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */) 580 clientHash = cs.mac(version, clientMAC) 581 serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */) 582 serverHash = cs.mac(version, serverMAC) 583 } else { 584 clientCipher = cs.aead(clientKey, clientIV) 585 serverCipher = cs.aead(serverKey, serverIV) 586 } 587 588 if isClient { 589 tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash) 590 tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash) 591 } else { 592 tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash) 593 tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash) 594 } 595 596 // skip the handshake states 597 tlsConn.handshakeStatus = 1 598 tlsConn.cipherSuite = cipherSuite 599 tlsConn.haveVers = true 600 tlsConn.vers = version 601 602 // Update to the new cipher specs 603 // and consume the finished messages 604 tlsConn.in.changeCipherSpec() 605 tlsConn.out.changeCipherSpec() 606 607 tlsConn.in.incSeq() 608 tlsConn.out.incSeq() 609 610 return tlsConn 611} 612 613func makeSupportedVersions(minVers, maxVers uint16) []uint16 { 614 a := make([]uint16, maxVers-minVers+1) 615 for i := range a { 616 a[i] = maxVers - uint16(i) 617 } 618 return a 619} 620