1// Copyright 2017 Google Inc. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package tls 6 7import ( 8 "bufio" 9 "bytes" 10 "crypto/cipher" 11 "encoding/binary" 12 "errors" 13 "fmt" 14 "io" 15 "net" 16 "strconv" 17 "sync/atomic" 18) 19 20type UConn struct { 21 *Conn 22 23 Extensions []TLSExtension 24 ClientHelloID ClientHelloID 25 26 ClientHelloBuilt bool 27 HandshakeState ClientHandshakeState 28 29 // sessionID may or may not depend on ticket; nil => random 30 GetSessionID func(ticket []byte) [32]byte 31 32 greaseSeed [ssl_grease_last_index]uint16 33 34 omitSNIExtension bool 35} 36 37// UClient returns a new uTLS client, with behavior depending on clientHelloID. 38// Config CAN be nil, but make sure to eventually specify ServerName. 39func UClient(conn net.Conn, config *Config, clientHelloID ClientHelloID) *UConn { 40 if config == nil { 41 config = &Config{} 42 } 43 tlsConn := Conn{conn: conn, config: config, isClient: true} 44 handshakeState := ClientHandshakeState{C: &tlsConn, Hello: &ClientHelloMsg{}} 45 uconn := UConn{Conn: &tlsConn, ClientHelloID: clientHelloID, HandshakeState: handshakeState} 46 uconn.HandshakeState.uconn = &uconn 47 return &uconn 48} 49 50// BuildHandshakeState behavior varies based on ClientHelloID and 51// whether it was already called before. 52// If HelloGolang: 53// [only once] make default ClientHello and overwrite existing state 54// If any other mimicking ClientHelloID is used: 55// [only once] make ClientHello based on ID and overwrite existing state 56// [each call] apply uconn.Extensions config to internal crypto/tls structures 57// [each call] marshal ClientHello. 58// 59// BuildHandshakeState is automatically called before uTLS performs handshake, 60// amd should only be called explicitly to inspect/change fields of 61// default/mimicked ClientHello. 62func (uconn *UConn) BuildHandshakeState() error { 63 if uconn.ClientHelloID == HelloGolang { 64 if uconn.ClientHelloBuilt { 65 return nil 66 } 67 68 // use default Golang ClientHello. 69 hello, ecdheParams, err := uconn.makeClientHello() 70 if err != nil { 71 return err 72 } 73 74 uconn.HandshakeState.Hello = hello.getPublicPtr() 75 uconn.HandshakeState.State13.EcdheParams = ecdheParams 76 uconn.HandshakeState.C = uconn.Conn 77 } else { 78 if !uconn.ClientHelloBuilt { 79 err := uconn.applyPresetByID(uconn.ClientHelloID) 80 if err != nil { 81 return err 82 } 83 if uconn.omitSNIExtension { 84 uconn.removeSNIExtension() 85 } 86 } 87 88 err := uconn.ApplyConfig() 89 if err != nil { 90 return err 91 } 92 err = uconn.MarshalClientHello() 93 if err != nil { 94 return err 95 } 96 } 97 uconn.ClientHelloBuilt = true 98 return nil 99} 100 101// SetSessionState sets the session ticket, which may be preshared or fake. 102// If session is nil, the body of session ticket extension will be unset, 103// but the extension itself still MAY be present for mimicking purposes. 104// Session tickets to be reused - use same cache on following connections. 105func (uconn *UConn) SetSessionState(session *ClientSessionState) error { 106 uconn.HandshakeState.Session = session 107 var sessionTicket []uint8 108 if session != nil { 109 sessionTicket = session.sessionTicket 110 } 111 uconn.HandshakeState.Hello.TicketSupported = true 112 uconn.HandshakeState.Hello.SessionTicket = sessionTicket 113 114 for _, ext := range uconn.Extensions { 115 st, ok := ext.(*SessionTicketExtension) 116 if !ok { 117 continue 118 } 119 st.Session = session 120 if session != nil { 121 if len(session.SessionTicket()) > 0 { 122 if uconn.GetSessionID != nil { 123 sid := uconn.GetSessionID(session.SessionTicket()) 124 uconn.HandshakeState.Hello.SessionId = sid[:] 125 return nil 126 } 127 } 128 var sessionID [32]byte 129 _, err := io.ReadFull(uconn.config.rand(), sessionID[:]) 130 if err != nil { 131 return err 132 } 133 uconn.HandshakeState.Hello.SessionId = sessionID[:] 134 } 135 return nil 136 } 137 return nil 138} 139 140// If you want session tickets to be reused - use same cache on following connections 141func (uconn *UConn) SetSessionCache(cache ClientSessionCache) { 142 uconn.config.ClientSessionCache = cache 143 uconn.HandshakeState.Hello.TicketSupported = true 144} 145 146// SetClientRandom sets client random explicitly. 147// BuildHandshakeFirst() must be called before SetClientRandom. 148// r must to be 32 bytes long. 149func (uconn *UConn) SetClientRandom(r []byte) error { 150 if len(r) != 32 { 151 return errors.New("Incorrect client random length! Expected: 32, got: " + strconv.Itoa(len(r))) 152 } else { 153 uconn.HandshakeState.Hello.Random = make([]byte, 32) 154 copy(uconn.HandshakeState.Hello.Random, r) 155 return nil 156 } 157} 158 159func (uconn *UConn) SetSNI(sni string) { 160 hname := hostnameInSNI(sni) 161 uconn.config.ServerName = hname 162 for _, ext := range uconn.Extensions { 163 sniExt, ok := ext.(*SNIExtension) 164 if ok { 165 sniExt.ServerName = hname 166 } 167 } 168} 169 170// RemoveSNIExtension removes SNI from the list of extensions sent in ClientHello 171// It returns an error when used with HelloGolang ClientHelloID 172func (uconn *UConn) RemoveSNIExtension() error { 173 if uconn.ClientHelloID == HelloGolang { 174 return fmt.Errorf("Cannot call RemoveSNIExtension on a UConn with a HelloGolang ClientHelloID") 175 } 176 uconn.omitSNIExtension = true 177 return nil 178} 179 180func (uconn *UConn) removeSNIExtension() { 181 filteredExts := make([]TLSExtension, 0, len(uconn.Extensions)) 182 for _, e := range uconn.Extensions { 183 if _, ok := e.(*SNIExtension); !ok { 184 filteredExts = append(filteredExts, e) 185 } 186 } 187 uconn.Extensions = filteredExts 188} 189 190// Handshake runs the client handshake using given clientHandshakeState 191// Requires hs.hello, and, optionally, hs.session to be set. 192func (c *UConn) Handshake() error { 193 c.handshakeMutex.Lock() 194 defer c.handshakeMutex.Unlock() 195 196 if err := c.handshakeErr; err != nil { 197 return err 198 } 199 if c.handshakeComplete() { 200 return nil 201 } 202 203 c.in.Lock() 204 defer c.in.Unlock() 205 206 if c.isClient { 207 // [uTLS section begins] 208 err := c.BuildHandshakeState() 209 if err != nil { 210 return err 211 } 212 // [uTLS section ends] 213 214 c.handshakeErr = c.clientHandshake() 215 } else { 216 c.handshakeErr = c.serverHandshake() 217 } 218 if c.handshakeErr == nil { 219 c.handshakes++ 220 } else { 221 // If an error occurred during the hadshake try to flush the 222 // alert that might be left in the buffer. 223 c.flush() 224 } 225 226 if c.handshakeErr == nil && !c.handshakeComplete() { 227 c.handshakeErr = errors.New("tls: internal error: handshake should have had a result") 228 } 229 230 return c.handshakeErr 231} 232 233// Copy-pasted from tls.Conn in its entirety. But c.Handshake() is now utls' one, not tls. 234// Write writes data to the connection. 235func (c *UConn) Write(b []byte) (int, error) { 236 // interlock with Close below 237 for { 238 x := atomic.LoadInt32(&c.activeCall) 239 if x&1 != 0 { 240 return 0, errClosed 241 } 242 if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) { 243 defer atomic.AddInt32(&c.activeCall, -2) 244 break 245 } 246 } 247 248 if err := c.Handshake(); err != nil { 249 return 0, err 250 } 251 252 c.out.Lock() 253 defer c.out.Unlock() 254 255 if err := c.out.err; err != nil { 256 return 0, err 257 } 258 259 if !c.handshakeComplete() { 260 return 0, alertInternalError 261 } 262 263 if c.closeNotifySent { 264 return 0, errShutdown 265 } 266 267 // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext 268 // attack when using block mode ciphers due to predictable IVs. 269 // This can be prevented by splitting each Application Data 270 // record into two records, effectively randomizing the IV. 271 // 272 // https://www.openssl.org/~bodo/tls-cbc.txt 273 // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 274 // https://www.imperialviolet.org/2012/01/15/beastfollowup.html 275 276 var m int 277 if len(b) > 1 && c.vers <= VersionTLS10 { 278 if _, ok := c.out.cipher.(cipher.BlockMode); ok { 279 n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1]) 280 if err != nil { 281 return n, c.out.setErrorLocked(err) 282 } 283 m, b = 1, b[1:] 284 } 285 } 286 287 n, err := c.writeRecordLocked(recordTypeApplicationData, b) 288 return n + m, c.out.setErrorLocked(err) 289} 290 291// clientHandshakeWithOneState checks that exactly one expected state is set (1.2 or 1.3) 292// and performs client TLS handshake with that state 293func (c *UConn) clientHandshake() (err error) { 294 // [uTLS section begins] 295 hello := c.HandshakeState.Hello.getPrivatePtr() 296 defer func() { c.HandshakeState.Hello = hello.getPublicPtr() }() 297 298 sessionIsAlreadySet := c.HandshakeState.Session != nil 299 300 // after this point exactly 1 out of 2 HandshakeState pointers is non-nil, 301 // useTLS13 variable tells which pointer 302 // [uTLS section ends] 303 304 if c.config == nil { 305 c.config = defaultConfig() 306 } 307 308 // This may be a renegotiation handshake, in which case some fields 309 // need to be reset. 310 c.didResume = false 311 312 // [uTLS section begins] 313 // don't make new ClientHello, use hs.hello 314 // preserve the checks from beginning and end of makeClientHello() 315 if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify { 316 return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") 317 } 318 319 nextProtosLength := 0 320 for _, proto := range c.config.NextProtos { 321 if l := len(proto); l == 0 || l > 255 { 322 return errors.New("tls: invalid NextProtos value") 323 } else { 324 nextProtosLength += 1 + l 325 } 326 } 327 328 if nextProtosLength > 0xffff { 329 return errors.New("tls: NextProtos values too large") 330 } 331 332 if c.handshakes > 0 { 333 hello.secureRenegotiation = c.clientFinished[:] 334 } 335 // [uTLS section ends] 336 337 cacheKey, session, earlySecret, binderKey := c.loadSession(hello) 338 if cacheKey != "" && session != nil { 339 defer func() { 340 // If we got a handshake failure when resuming a session, throw away 341 // the session ticket. See RFC 5077, Section 3.2. 342 // 343 // RFC 8446 makes no mention of dropping tickets on failure, but it 344 // does require servers to abort on invalid binders, so we need to 345 // delete tickets to recover from a corrupted PSK. 346 if err != nil { 347 c.config.ClientSessionCache.Put(cacheKey, nil) 348 } 349 }() 350 } 351 352 if !sessionIsAlreadySet { // uTLS: do not overwrite already set session 353 err = c.SetSessionState(session) 354 if err != nil { 355 return 356 } 357 } 358 359 if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { 360 return err 361 } 362 363 msg, err := c.readHandshake() 364 if err != nil { 365 return err 366 } 367 368 serverHello, ok := msg.(*serverHelloMsg) 369 if !ok { 370 c.sendAlert(alertUnexpectedMessage) 371 return unexpectedMessageError(serverHello, msg) 372 } 373 374 if err := c.pickTLSVersion(serverHello); err != nil { 375 return err 376 } 377 378 // uTLS: do not create new handshakeState, use existing one 379 if c.vers == VersionTLS13 { 380 hs13 := c.HandshakeState.toPrivate13() 381 hs13.serverHello = serverHello 382 hs13.hello = hello 383 if !sessionIsAlreadySet { 384 hs13.earlySecret = earlySecret 385 hs13.binderKey = binderKey 386 } 387 // In TLS 1.3, session tickets are delivered after the handshake. 388 err = hs13.handshake() 389 if handshakeState := hs13.toPublic13(); handshakeState != nil { 390 c.HandshakeState = *handshakeState 391 } 392 return err 393 } 394 395 hs12 := c.HandshakeState.toPrivate12() 396 hs12.serverHello = serverHello 397 hs12.hello = hello 398 err = hs12.handshake() 399 if handshakeState := hs12.toPublic12(); handshakeState != nil { 400 c.HandshakeState = *handshakeState 401 } 402 if err != nil { 403 return err 404 } 405 406 // If we had a successful handshake and hs.session is different from 407 // the one already cached - cache a new one. 408 if cacheKey != "" && hs12.session != nil && session != hs12.session { 409 c.config.ClientSessionCache.Put(cacheKey, hs12.session) 410 } 411 return nil 412} 413 414func (uconn *UConn) ApplyConfig() error { 415 for _, ext := range uconn.Extensions { 416 err := ext.writeToUConn(uconn) 417 if err != nil { 418 return err 419 } 420 } 421 return nil 422} 423 424func (uconn *UConn) MarshalClientHello() error { 425 hello := uconn.HandshakeState.Hello 426 headerLength := 2 + 32 + 1 + len(hello.SessionId) + 427 2 + len(hello.CipherSuites)*2 + 428 1 + len(hello.CompressionMethods) 429 430 extensionsLen := 0 431 var paddingExt *UtlsPaddingExtension 432 for _, ext := range uconn.Extensions { 433 if pe, ok := ext.(*UtlsPaddingExtension); !ok { 434 // If not padding - just add length of extension to total length 435 extensionsLen += ext.Len() 436 } else { 437 // If padding - process it later 438 if paddingExt == nil { 439 paddingExt = pe 440 } else { 441 return errors.New("Multiple padding extensions!") 442 } 443 } 444 } 445 446 if paddingExt != nil { 447 // determine padding extension presence and length 448 paddingExt.Update(headerLength + 4 + extensionsLen + 2) 449 extensionsLen += paddingExt.Len() 450 } 451 452 helloLen := headerLength 453 if len(uconn.Extensions) > 0 { 454 helloLen += 2 + extensionsLen // 2 bytes for extensions' length 455 } 456 457 helloBuffer := bytes.Buffer{} 458 bufferedWriter := bufio.NewWriterSize(&helloBuffer, helloLen+4) // 1 byte for tls record type, 3 for length 459 // We use buffered Writer to avoid checking write errors after every Write(): whenever first error happens 460 // Write() will become noop, and error will be accessible via Flush(), which is called once in the end 461 462 binary.Write(bufferedWriter, binary.BigEndian, typeClientHello) 463 helloLenBytes := []byte{byte(helloLen >> 16), byte(helloLen >> 8), byte(helloLen)} // poor man's uint24 464 binary.Write(bufferedWriter, binary.BigEndian, helloLenBytes) 465 binary.Write(bufferedWriter, binary.BigEndian, hello.Vers) 466 467 binary.Write(bufferedWriter, binary.BigEndian, hello.Random) 468 469 binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.SessionId))) 470 binary.Write(bufferedWriter, binary.BigEndian, hello.SessionId) 471 472 binary.Write(bufferedWriter, binary.BigEndian, uint16(len(hello.CipherSuites)<<1)) 473 for _, suite := range hello.CipherSuites { 474 binary.Write(bufferedWriter, binary.BigEndian, suite) 475 } 476 477 binary.Write(bufferedWriter, binary.BigEndian, uint8(len(hello.CompressionMethods))) 478 binary.Write(bufferedWriter, binary.BigEndian, hello.CompressionMethods) 479 480 if len(uconn.Extensions) > 0 { 481 binary.Write(bufferedWriter, binary.BigEndian, uint16(extensionsLen)) 482 for _, ext := range uconn.Extensions { 483 bufferedWriter.ReadFrom(ext) 484 } 485 } 486 487 err := bufferedWriter.Flush() 488 if err != nil { 489 return err 490 } 491 492 if helloBuffer.Len() != 4+helloLen { 493 return errors.New("utls: unexpected ClientHello length. Expected: " + strconv.Itoa(4+helloLen) + 494 ". Got: " + strconv.Itoa(helloBuffer.Len())) 495 } 496 497 hello.Raw = helloBuffer.Bytes() 498 return nil 499} 500 501// get current state of cipher and encrypt zeros to get keystream 502func (uconn *UConn) GetOutKeystream(length int) ([]byte, error) { 503 zeros := make([]byte, length) 504 505 if outCipher, ok := uconn.out.cipher.(cipher.AEAD); ok { 506 // AEAD.Seal() does not mutate internal state, other ciphers might 507 return outCipher.Seal(nil, uconn.out.seq[:], zeros, nil), nil 508 } 509 return nil, errors.New("Could not convert OutCipher to cipher.AEAD") 510} 511 512// SetTLSVers sets min and max TLS version in all appropriate places. 513// Function will use first non-zero version parsed in following order: 514// 1) Provided minTLSVers, maxTLSVers 515// 2) specExtensions may have SupportedVersionsExtension 516// 3) [default] min = TLS 1.0, max = TLS 1.2 517// 518// Error is only returned if things are in clearly undesirable state 519// to help user fix them. 520func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []TLSExtension) error { 521 if minTLSVers == 0 && maxTLSVers == 0 { 522 // if version is not set explicitly in the ClientHelloSpec, check the SupportedVersions extension 523 supportedVersionsExtensionsPresent := 0 524 for _, e := range specExtensions { 525 switch ext := e.(type) { 526 case *SupportedVersionsExtension: 527 findVersionsInSupportedVersionsExtensions := func(versions []uint16) (uint16, uint16) { 528 // returns (minVers, maxVers) 529 minVers := uint16(0) 530 maxVers := uint16(0) 531 for _, vers := range versions { 532 if vers == GREASE_PLACEHOLDER { 533 continue 534 } 535 if maxVers < vers || maxVers == 0 { 536 maxVers = vers 537 } 538 if minVers > vers || minVers == 0 { 539 minVers = vers 540 } 541 } 542 return minVers, maxVers 543 } 544 545 supportedVersionsExtensionsPresent += 1 546 minTLSVers, maxTLSVers = findVersionsInSupportedVersionsExtensions(ext.Versions) 547 if minTLSVers == 0 && maxTLSVers == 0 { 548 return fmt.Errorf("SupportedVersions extension has invalid Versions field") 549 } // else: proceed 550 } 551 } 552 switch supportedVersionsExtensionsPresent { 553 case 0: 554 // if mandatory for TLS 1.3 extension is not present, just default to 1.2 555 minTLSVers = VersionTLS10 556 maxTLSVers = VersionTLS12 557 case 1: 558 default: 559 return fmt.Errorf("uconn.Extensions contains %v separate SupportedVersions extensions", 560 supportedVersionsExtensionsPresent) 561 } 562 } 563 564 if minTLSVers < VersionTLS10 || minTLSVers > VersionTLS12 { 565 return fmt.Errorf("uTLS does not support 0x%X as min version", minTLSVers) 566 } 567 568 if maxTLSVers < VersionTLS10 || maxTLSVers > VersionTLS13 { 569 return fmt.Errorf("uTLS does not support 0x%X as max version", maxTLSVers) 570 } 571 572 uconn.HandshakeState.Hello.SupportedVersions = makeSupportedVersions(minTLSVers, maxTLSVers) 573 uconn.config.MinVersion = minTLSVers 574 uconn.config.MaxVersion = maxTLSVers 575 576 return nil 577} 578 579func (uconn *UConn) SetUnderlyingConn(c net.Conn) { 580 uconn.Conn.conn = c 581} 582 583func (uconn *UConn) GetUnderlyingConn() net.Conn { 584 return uconn.Conn.conn 585} 586 587// MakeConnWithCompleteHandshake allows to forge both server and client side TLS connections. 588// Major Hack Alert. 589func MakeConnWithCompleteHandshake(tcpConn net.Conn, version uint16, cipherSuite uint16, masterSecret []byte, clientRandom []byte, serverRandom []byte, isClient bool) *Conn { 590 tlsConn := &Conn{conn: tcpConn, config: &Config{}, isClient: isClient} 591 cs := cipherSuiteByID(cipherSuite) 592 if cs == nil { 593 return nil 594 } 595 596 // This is mostly borrowed from establishKeys() 597 clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := 598 keysFromMasterSecret(version, cs, masterSecret, clientRandom, serverRandom, 599 cs.macLen, cs.keyLen, cs.ivLen) 600 601 var clientCipher, serverCipher interface{} 602 var clientHash, serverHash macFunction 603 if cs.cipher != nil { 604 clientCipher = cs.cipher(clientKey, clientIV, true /* for reading */) 605 clientHash = cs.mac(version, clientMAC) 606 serverCipher = cs.cipher(serverKey, serverIV, false /* not for reading */) 607 serverHash = cs.mac(version, serverMAC) 608 } else { 609 clientCipher = cs.aead(clientKey, clientIV) 610 serverCipher = cs.aead(serverKey, serverIV) 611 } 612 613 if isClient { 614 tlsConn.in.prepareCipherSpec(version, serverCipher, serverHash) 615 tlsConn.out.prepareCipherSpec(version, clientCipher, clientHash) 616 } else { 617 tlsConn.in.prepareCipherSpec(version, clientCipher, clientHash) 618 tlsConn.out.prepareCipherSpec(version, serverCipher, serverHash) 619 } 620 621 // skip the handshake states 622 tlsConn.handshakeStatus = 1 623 tlsConn.cipherSuite = cipherSuite 624 tlsConn.haveVers = true 625 tlsConn.vers = version 626 627 // Update to the new cipher specs 628 // and consume the finished messages 629 tlsConn.in.changeCipherSpec() 630 tlsConn.out.changeCipherSpec() 631 632 tlsConn.in.incSeq() 633 tlsConn.out.incSeq() 634 635 return tlsConn 636} 637 638func makeSupportedVersions(minVers, maxVers uint16) []uint16 { 639 a := make([]uint16, maxVers-minVers+1) 640 for i := range a { 641 a[i] = maxVers - uint16(i) 642 } 643 return a 644} 645