1package qtls 2 3import ( 4 "bytes" 5 "fmt" 6 "net" 7 "testing" 8 "time" 9) 10 11type recordLayer struct { 12 in <-chan []byte 13 out chan<- []byte 14 15 alertSent alert 16} 17 18func (r *recordLayer) SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) { 19} 20func (r *recordLayer) SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) { 21} 22func (r *recordLayer) ReadHandshakeMessage() ([]byte, error) { return <-r.in, nil } 23func (r *recordLayer) WriteRecord(b []byte) (int, error) { r.out <- b; return len(b), nil } 24func (r *recordLayer) SendAlert(a uint8) { r.alertSent = alert(a) } 25 26type exportedKey struct { 27 typ string // "read" or "write" 28 encLevel EncryptionLevel 29 suite *CipherSuiteTLS13 30 trafficSecret []byte 31} 32 33func compareExportedKeys(t *testing.T, k1, k2 *exportedKey) { 34 if k1.encLevel != k2.encLevel || k1.suite.ID != k2.suite.ID || !bytes.Equal(k1.trafficSecret, k2.trafficSecret) { 35 t.Fatal("mismatching keys") 36 } 37} 38 39type recordLayerWithKeys struct { 40 in <-chan []byte 41 out chan<- interface{} 42} 43 44func (r *recordLayerWithKeys) SetReadKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) { 45 r.out <- &exportedKey{typ: "read", encLevel: encLevel, suite: suite, trafficSecret: trafficSecret} 46} 47func (r *recordLayerWithKeys) SetWriteKey(encLevel EncryptionLevel, suite *CipherSuiteTLS13, trafficSecret []byte) { 48 r.out <- &exportedKey{typ: "write", encLevel: encLevel, suite: suite, trafficSecret: trafficSecret} 49} 50func (r *recordLayerWithKeys) ReadHandshakeMessage() ([]byte, error) { return <-r.in, nil } 51func (r *recordLayerWithKeys) WriteRecord(b []byte) (int, error) { r.out <- b; return len(b), nil } 52func (r *recordLayerWithKeys) SendAlert(uint8) {} 53 54type unusedConn struct { 55 remoteAddr net.Addr 56} 57 58var _ net.Conn = &unusedConn{} 59 60func (unusedConn) Read([]byte) (int, error) { panic("unexpected call to Read()") } 61func (unusedConn) Write([]byte) (int, error) { panic("unexpected call to Write()") } 62func (unusedConn) Close() error { return nil } 63func (unusedConn) LocalAddr() net.Addr { return &net.TCPAddr{} } 64func (c *unusedConn) RemoteAddr() net.Addr { return c.remoteAddr } 65func (unusedConn) SetDeadline(time.Time) error { return nil } 66func (unusedConn) SetReadDeadline(time.Time) error { return nil } 67func (unusedConn) SetWriteDeadline(time.Time) error { return nil } 68 69func TestAlternativeRecordLayer(t *testing.T) { 70 sIn := make(chan []byte, 10) 71 sOut := make(chan interface{}, 10) 72 defer close(sOut) 73 cIn := make(chan []byte, 10) 74 cOut := make(chan interface{}, 10) 75 defer close(cOut) 76 77 serverEvents := make(chan interface{}, 100) 78 go func() { 79 for { 80 c, ok := <-sOut 81 if !ok { 82 return 83 } 84 serverEvents <- c 85 if b, ok := c.([]byte); ok { 86 cIn <- b 87 } 88 } 89 }() 90 91 clientEvents := make(chan interface{}, 100) 92 go func() { 93 for { 94 c, ok := <-cOut 95 if !ok { 96 return 97 } 98 clientEvents <- c 99 if b, ok := c.([]byte); ok { 100 sIn <- b 101 } 102 } 103 }() 104 105 errChan := make(chan error) 106 go func() { 107 extraConf := &ExtraConfig{ 108 AlternativeRecordLayer: &recordLayerWithKeys{in: sIn, out: sOut}, 109 } 110 tlsConn := Server(&unusedConn{}, testConfig, extraConf) 111 defer tlsConn.Close() 112 errChan <- tlsConn.Handshake() 113 }() 114 115 extraConf := &ExtraConfig{ 116 AlternativeRecordLayer: &recordLayerWithKeys{in: cIn, out: cOut}, 117 } 118 tlsConn := Client(&unusedConn{}, testConfig, extraConf) 119 defer tlsConn.Close() 120 if err := tlsConn.Handshake(); err != nil { 121 t.Fatalf("Handshake failed: %s", err) 122 } 123 124 // Handshakes completed. Now check that events were received in the correct order. 125 var clientHandshakeReadKey, clientHandshakeWriteKey *exportedKey 126 var clientApplicationReadKey, clientApplicationWriteKey *exportedKey 127 for i := 0; i <= 5; i++ { 128 ev := <-clientEvents 129 switch i { 130 case 0: 131 if ev.([]byte)[0] != typeClientHello { 132 t.Fatalf("expected ClientHello") 133 } 134 case 1: 135 keyEv := ev.(*exportedKey) 136 if keyEv.typ != "write" || keyEv.encLevel != EncryptionHandshake { 137 t.Fatalf("expected the handshake write key") 138 } 139 clientHandshakeWriteKey = keyEv 140 case 2: 141 keyEv := ev.(*exportedKey) 142 if keyEv.typ != "read" || keyEv.encLevel != EncryptionHandshake { 143 t.Fatalf("expected the handshake read key") 144 } 145 clientHandshakeReadKey = keyEv 146 case 3: 147 keyEv := ev.(*exportedKey) 148 if keyEv.typ != "read" || keyEv.encLevel != EncryptionApplication { 149 t.Fatalf("expected the application read key") 150 } 151 clientApplicationReadKey = keyEv 152 case 4: 153 if ev.([]byte)[0] != typeFinished { 154 t.Fatalf("expected Finished") 155 } 156 case 5: 157 keyEv := ev.(*exportedKey) 158 if keyEv.typ != "write" || keyEv.encLevel != EncryptionApplication { 159 t.Fatalf("expected the application write key") 160 } 161 clientApplicationWriteKey = keyEv 162 } 163 } 164 if len(clientEvents) > 0 { 165 t.Fatal("didn't expect any more client events") 166 } 167 168 for i := 0; i <= 8; i++ { 169 ev := <-serverEvents 170 switch i { 171 case 0: 172 if ev.([]byte)[0] != typeServerHello { 173 t.Fatalf("expected ServerHello") 174 } 175 case 1: 176 keyEv := ev.(*exportedKey) 177 if keyEv.typ != "read" || keyEv.encLevel != EncryptionHandshake { 178 t.Fatalf("expected the handshake read key") 179 } 180 compareExportedKeys(t, clientHandshakeWriteKey, keyEv) 181 case 2: 182 keyEv := ev.(*exportedKey) 183 if keyEv.typ != "write" || keyEv.encLevel != EncryptionHandshake { 184 t.Fatalf("expected the handshake write key") 185 } 186 compareExportedKeys(t, clientHandshakeReadKey, keyEv) 187 case 3: 188 if ev.([]byte)[0] != typeEncryptedExtensions { 189 t.Fatalf("expected EncryptedExtensions") 190 } 191 case 4: 192 if ev.([]byte)[0] != typeCertificate { 193 t.Fatalf("expected Certificate") 194 } 195 case 5: 196 if ev.([]byte)[0] != typeCertificateVerify { 197 t.Fatalf("expected CertificateVerify") 198 } 199 case 6: 200 if ev.([]byte)[0] != typeFinished { 201 t.Fatalf("expected Finished") 202 } 203 case 7: 204 keyEv := ev.(*exportedKey) 205 if keyEv.typ != "write" || keyEv.encLevel != EncryptionApplication { 206 t.Fatalf("expected the application write key") 207 } 208 compareExportedKeys(t, clientApplicationReadKey, keyEv) 209 case 8: 210 keyEv := ev.(*exportedKey) 211 if keyEv.typ != "read" || keyEv.encLevel != EncryptionApplication { 212 t.Fatalf("expected the application read key") 213 } 214 compareExportedKeys(t, clientApplicationWriteKey, keyEv) 215 } 216 } 217 if len(serverEvents) > 0 { 218 t.Fatal("didn't expect any more server events") 219 } 220} 221 222func TestErrorOnOldTLSVersions(t *testing.T) { 223 sIn := make(chan []byte, 10) 224 cIn := make(chan []byte, 10) 225 cOut := make(chan []byte, 10) 226 227 go func() { 228 for { 229 b, ok := <-cOut 230 if !ok { 231 return 232 } 233 if b[0] == typeClientHello { 234 m := new(clientHelloMsg) 235 if !m.unmarshal(b) { 236 panic("unmarshal failed") 237 } 238 m.raw = nil // need to reset, so marshal() actually marshals the changes 239 m.supportedVersions = []uint16{VersionTLS11, VersionTLS13} 240 b = m.marshal() 241 } 242 sIn <- b 243 } 244 }() 245 246 done := make(chan struct{}) 247 go func() { 248 defer close(done) 249 extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: cIn, out: cOut}} 250 Client(&unusedConn{}, testConfig, extraConf).Handshake() 251 }() 252 253 serverRecordLayer := &recordLayer{in: sIn, out: cIn} 254 extraConf := &ExtraConfig{AlternativeRecordLayer: serverRecordLayer} 255 tlsConn := Server(&unusedConn{}, testConfig, extraConf) 256 defer tlsConn.Close() 257 err := tlsConn.Handshake() 258 if err == nil || err.Error() != "tls: client offered old TLS version 0x302" { 259 t.Fatal("expected the server to error when the client offers old versions") 260 } 261 if serverRecordLayer.alertSent != alertProtocolVersion { 262 t.Fatal("expected a protocol version alert to be sent") 263 } 264 265 cIn <- []byte{'f'} 266 <-done 267} 268 269func TestRejectConfigWithOldMaxVersion(t *testing.T) { 270 t.Run("for the client", func(t *testing.T) { 271 config := testConfig.Clone() 272 config.MaxVersion = VersionTLS12 273 tlsConn := Client(&unusedConn{}, config, &ExtraConfig{AlternativeRecordLayer: &recordLayer{}}) 274 err := tlsConn.Handshake() 275 if err == nil || err.Error() != "tls: MaxVersion prevents QUIC from using TLS 1.3" { 276 t.Errorf("expected the handshake to fail") 277 } 278 }) 279 280 t.Run("for the server", func(t *testing.T) { 281 in := make(chan []byte, 10) 282 out := make(chan []byte, 10) 283 284 done := make(chan struct{}) 285 go func() { 286 defer close(done) 287 Client( 288 &unusedConn{}, 289 testConfig, 290 &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: in, out: out}}, 291 ).Handshake() 292 }() 293 294 config := testConfig.Clone() 295 config.MaxVersion = VersionTLS12 296 serverRecordLayer := &recordLayer{in: out, out: in} 297 err := Server( 298 &unusedConn{}, 299 config, 300 &ExtraConfig{AlternativeRecordLayer: serverRecordLayer}, 301 ).Handshake() 302 if err == nil || err.Error() != "tls: MaxVersion prevents QUIC from using TLS 1.3" { 303 t.Errorf("expected the handshake to fail") 304 } 305 if serverRecordLayer.alertSent != alertInternalError { 306 t.Fatal("expected an internal error alert to be sent") 307 } 308 }) 309 310 t.Run("for the server (using GetConfigForClient)", func(t *testing.T) { 311 in := make(chan []byte, 10) 312 out := make(chan []byte, 10) 313 314 done := make(chan struct{}) 315 go func() { 316 defer close(done) 317 Client( 318 &unusedConn{}, 319 testConfig, 320 &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: in, out: out}}, 321 ).Handshake() 322 }() 323 324 config := testConfig.Clone() 325 config.GetConfigForClient = func(*ClientHelloInfo) (*Config, error) { 326 conf := testConfig.Clone() 327 conf.MaxVersion = VersionTLS12 328 return conf, nil 329 } 330 serverRecordLayer := &recordLayer{in: out, out: in} 331 err := Server( 332 &unusedConn{}, 333 config, 334 &ExtraConfig{AlternativeRecordLayer: serverRecordLayer}, 335 ).Handshake() 336 if err == nil || err.Error() != "tls: MaxVersion prevents QUIC from using TLS 1.3" { 337 t.Errorf("expected the handshake to fail") 338 } 339 if serverRecordLayer.alertSent != alertInternalError { 340 t.Fatal("expected an internal error alert to be sent") 341 } 342 }) 343} 344 345func TestForbiddenZeroRTT(t *testing.T) { 346 // run the first handshake to get a session ticket 347 clientConn, serverConn := localPipe(t) 348 errChan := make(chan error, 1) 349 go func() { 350 tlsConn := Server(serverConn, testConfig.Clone(), nil) 351 defer tlsConn.Close() 352 err := tlsConn.Handshake() 353 errChan <- err 354 if err != nil { 355 return 356 } 357 tlsConn.Write([]byte{0}) 358 }() 359 360 clientConfig := testConfig.Clone() 361 clientConfig.ClientSessionCache = NewLRUClientSessionCache(10) 362 tlsConn := Client(clientConn, clientConfig, nil) 363 if err := tlsConn.Handshake(); err != nil { 364 t.Fatalf("first handshake failed: %s", err) 365 } 366 tlsConn.Read([]byte{0}) // make sure to read the session ticket 367 tlsConn.Close() 368 if err := <-errChan; err != nil { 369 t.Fatalf("first handshake failed: %s", err) 370 } 371 372 sIn := make(chan []byte, 10) 373 cIn := make(chan []byte, 10) 374 cOut := make(chan []byte, 10) 375 376 go func() { 377 for { 378 b, ok := <-cOut 379 if !ok { 380 return 381 } 382 if b[0] == typeClientHello { 383 msg := &clientHelloMsg{} 384 if ok := msg.unmarshal(b); !ok { 385 panic("unmarshaling failed") 386 } 387 msg.earlyData = true 388 msg.raw = nil 389 b = msg.marshal() 390 } 391 sIn <- b 392 } 393 }() 394 395 done := make(chan struct{}) 396 go func() { 397 defer close(done) 398 extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: cIn, out: cOut}} 399 Client(&unusedConn{remoteAddr: clientConn.RemoteAddr()}, clientConfig, extraConf).Handshake() 400 }() 401 402 config := testConfig.Clone() 403 config.MinVersion = VersionTLS13 404 serverRecordLayer := &recordLayer{in: sIn, out: cIn} 405 extraConf := &ExtraConfig{AlternativeRecordLayer: serverRecordLayer} 406 tlsConn = Server(&unusedConn{}, config, extraConf) 407 err := tlsConn.Handshake() 408 if err == nil { 409 t.Fatal("expected handshake to fail") 410 } 411 if err.Error() != "tls: client sent unexpected early data" { 412 t.Fatalf("expected early data error") 413 } 414 if serverRecordLayer.alertSent != alertUnsupportedExtension { 415 t.Fatal("expected an unsupported extension alert to be sent") 416 } 417 cIn <- []byte{0} // make the client handshake error 418 <-done 419} 420 421func TestZeroRTTKeys(t *testing.T) { 422 // run the first handshake to get a session ticket 423 clientConn, serverConn := localPipe(t) 424 errChan := make(chan error, 1) 425 go func() { 426 extraConf := &ExtraConfig{MaxEarlyData: 1000} 427 tlsConn := Server(serverConn, testConfig, extraConf) 428 defer tlsConn.Close() 429 err := tlsConn.Handshake() 430 errChan <- err 431 if err != nil { 432 return 433 } 434 tlsConn.Write([]byte{0}) 435 }() 436 437 clientConfig := testConfig.Clone() 438 clientConfig.ClientSessionCache = NewLRUClientSessionCache(10) 439 tlsConn := Client(clientConn, clientConfig, nil) 440 if err := tlsConn.Handshake(); err != nil { 441 t.Fatalf("first handshake failed: %s", err) 442 } 443 tlsConn.Read([]byte{0}) // make sure to read the session ticket 444 tlsConn.Close() 445 if err := <-errChan; err != nil { 446 t.Fatalf("first handshake failed: %s", err) 447 } 448 449 sIn := make(chan []byte, 10) 450 sOut := make(chan interface{}, 10) 451 defer close(sOut) 452 cIn := make(chan []byte, 10) 453 cOut := make(chan interface{}, 10) 454 defer close(cOut) 455 456 var serverEarlyData bool 457 var serverExportedKey *exportedKey 458 go func() { 459 for { 460 c, ok := <-sOut 461 if !ok { 462 return 463 } 464 if b, ok := c.([]byte); ok { 465 if b[0] == typeEncryptedExtensions { 466 var msg encryptedExtensionsMsg 467 if ok := msg.unmarshal(b); !ok { 468 panic("failed to unmarshal EncryptedExtensions") 469 } 470 serverEarlyData = msg.earlyData 471 } 472 cIn <- b 473 } 474 if k, ok := c.(*exportedKey); ok && k.encLevel == Encryption0RTT { 475 serverExportedKey = k 476 } 477 } 478 }() 479 480 var clientEarlyData bool 481 var clientExportedKey *exportedKey 482 go func() { 483 for { 484 c, ok := <-cOut 485 if !ok { 486 return 487 } 488 if b, ok := c.([]byte); ok { 489 if b[0] == typeClientHello { 490 var msg clientHelloMsg 491 if ok := msg.unmarshal(b); !ok { 492 panic("failed to unmarshal ClientHello") 493 } 494 clientEarlyData = msg.earlyData 495 } 496 sIn <- b 497 } 498 if k, ok := c.(*exportedKey); ok && k.encLevel == Encryption0RTT { 499 clientExportedKey = k 500 } 501 } 502 }() 503 504 errChan = make(chan error) 505 go func() { 506 extraConf := &ExtraConfig{ 507 AlternativeRecordLayer: &recordLayerWithKeys{in: sIn, out: sOut}, 508 MaxEarlyData: 1, 509 Accept0RTT: func([]byte) bool { return true }, 510 } 511 tlsConn := Server(&unusedConn{}, testConfig, extraConf) 512 defer tlsConn.Close() 513 errChan <- tlsConn.Handshake() 514 }() 515 516 extraConf := &ExtraConfig{ 517 AlternativeRecordLayer: &recordLayerWithKeys{in: cIn, out: cOut}, 518 Enable0RTT: true, 519 } 520 tlsConn = Client(&unusedConn{remoteAddr: clientConn.RemoteAddr()}, clientConfig, extraConf) 521 defer tlsConn.Close() 522 if err := tlsConn.Handshake(); err != nil { 523 t.Fatalf("Handshake failed: %s", err) 524 } 525 if err := <-errChan; err != nil { 526 t.Fatalf("Handshake failed: %s", err) 527 } 528 529 if !clientEarlyData { 530 t.Fatal("expected the client to offer early data") 531 } 532 if !serverEarlyData { 533 t.Fatal("expected the server to offer early data") 534 } 535 compareExportedKeys(t, clientExportedKey, serverExportedKey) 536} 537 538func TestEncodeIntoSessionTicket(t *testing.T) { 539 raddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} 540 sIn := make(chan []byte, 10) 541 sOut := make(chan []byte, 10) 542 543 // do a first handshake and encode a "foobar" into the session ticket 544 errChan := make(chan error, 1) 545 stChan := make(chan []byte, 1) 546 go func() { 547 extraConf := &ExtraConfig{ 548 AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut}, 549 MaxEarlyData: 1, 550 } 551 server := Server(&unusedConn{remoteAddr: raddr}, testConfig, extraConf) 552 defer server.Close() 553 err := server.Handshake() 554 if err != nil { 555 errChan <- err 556 return 557 } 558 st, err := server.GetSessionTicket([]byte("foobar")) 559 if err != nil { 560 errChan <- err 561 return 562 } 563 stChan <- st 564 errChan <- nil 565 }() 566 567 clientConf := testConfig.Clone() 568 clientConf.ClientSessionCache = NewLRUClientSessionCache(10) 569 extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: sOut, out: sIn}} 570 client := Client(&unusedConn{remoteAddr: raddr}, clientConf, extraConf) 571 if err := client.Handshake(); err != nil { 572 t.Fatalf("first handshake failed %s", err) 573 } 574 if err := <-errChan; err != nil { 575 t.Fatalf("first handshake failed %s", err) 576 } 577 sOut <- <-stChan 578 if err := client.HandlePostHandshakeMessage(); err != nil { 579 t.Fatalf("handling the session ticket failed: %s", err) 580 } 581 client.Close() 582 583 dataChan := make(chan []byte, 1) 584 errChan = make(chan error, 1) 585 go func() { 586 extraConf := &ExtraConfig{ 587 AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut}, 588 MaxEarlyData: 1, 589 Accept0RTT: func(data []byte) bool { 590 dataChan <- data 591 return true 592 }, 593 } 594 server := Server(&unusedConn{remoteAddr: raddr}, testConfig, extraConf) 595 defer server.Close() 596 errChan <- server.Handshake() 597 }() 598 599 extraConf2 := extraConf.Clone() 600 extraConf2.Enable0RTT = true 601 client = Client(&unusedConn{remoteAddr: raddr}, clientConf, extraConf2) 602 if err := client.Handshake(); err != nil { 603 t.Fatalf("second handshake failed %s", err) 604 } 605 defer client.Close() 606 if err := <-errChan; err != nil { 607 t.Fatalf("second handshake failed %s", err) 608 } 609 if len(dataChan) != 1 { 610 t.Fatal("expected to receive application data") 611 } 612 if data := <-dataChan; !bytes.Equal(data, []byte("foobar")) { 613 t.Fatalf("expected to receive a foobar, got %s", string(data)) 614 } 615} 616 617func TestZeroRTTRejection(t *testing.T) { 618 for _, doReject := range []bool{true, false} { 619 t.Run(fmt.Sprintf("doing reject: %t", doReject), func(t *testing.T) { 620 raddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} 621 sIn := make(chan []byte, 10) 622 sOut := make(chan []byte, 10) 623 624 // do a first handshake and encode a "foobar" into the session ticket 625 errChan := make(chan error, 1) 626 go func() { 627 extraConf := &ExtraConfig{ 628 AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut}, 629 MaxEarlyData: 1, 630 } 631 server := Server(&unusedConn{remoteAddr: raddr}, testConfig, extraConf) 632 defer server.Close() 633 err := server.Handshake() 634 if err != nil { 635 errChan <- err 636 return 637 } 638 st, err := server.GetSessionTicket(nil) 639 if err != nil { 640 errChan <- err 641 return 642 } 643 sOut <- st 644 errChan <- nil 645 }() 646 647 conf := testConfig.Clone() 648 conf.ClientSessionCache = NewLRUClientSessionCache(10) 649 extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: sOut, out: sIn}} 650 client := Client(&unusedConn{remoteAddr: raddr}, conf, extraConf) 651 if err := client.Handshake(); err != nil { 652 t.Fatalf("first handshake failed %s", err) 653 } 654 if err := <-errChan; err != nil { 655 t.Fatalf("first handshake failed %s", err) 656 } 657 if err := client.HandlePostHandshakeMessage(); err != nil { 658 t.Fatalf("handling the session ticket failed: %s", err) 659 } 660 client.Close() 661 662 // now dial the second connection 663 errChan = make(chan error, 1) 664 connStateChan := make(chan ConnectionStateWith0RTT, 1) 665 go func() { 666 extraConf := &ExtraConfig{ 667 AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut}, 668 MaxEarlyData: 1, 669 Accept0RTT: func(data []byte) bool { return !doReject }, 670 } 671 server := Server(&unusedConn{remoteAddr: raddr}, testConfig, extraConf) 672 defer server.Close() 673 errChan <- server.Handshake() 674 connStateChan <- server.ConnectionStateWith0RTT() 675 }() 676 677 extraConf2 := extraConf.Clone() 678 extraConf2.Enable0RTT = true 679 var rejected bool 680 extraConf2.Rejected0RTT = func() { rejected = true } 681 client = Client(&unusedConn{remoteAddr: raddr}, conf, extraConf2) 682 if err := client.Handshake(); err != nil { 683 t.Fatalf("second handshake failed %s", err) 684 } 685 defer client.Close() 686 if err := <-errChan; err != nil { 687 t.Fatalf("second handshake failed %s", err) 688 } 689 if rejected != doReject { 690 t.Fatal("wrong rejection") 691 } 692 if client.ConnectionStateWith0RTT().Used0RTT == doReject { 693 t.Fatal("wrong connection state on the client") 694 } 695 if (<-connStateChan).Used0RTT == doReject { 696 t.Fatal("wrong connection state on the server") 697 } 698 }) 699 } 700} 701 702func TestZeroRTTALPN(t *testing.T) { 703 run := func(t *testing.T, proto1, proto2 string, expectReject bool) { 704 raddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} 705 sIn := make(chan []byte, 10) 706 sOut := make(chan []byte, 10) 707 708 // do a first handshake and encode a "foobar" into the session ticket 709 errChan := make(chan error, 1) 710 go func() { 711 serverConf := testConfig.Clone() 712 serverConf.NextProtos = []string{proto1} 713 extraConf := &ExtraConfig{ 714 AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut}, 715 MaxEarlyData: 1, 716 } 717 server := Server(&unusedConn{remoteAddr: raddr}, serverConf, extraConf) 718 defer server.Close() 719 err := server.Handshake() 720 if err != nil { 721 errChan <- err 722 return 723 } 724 st, err := server.GetSessionTicket(nil) 725 if err != nil { 726 errChan <- err 727 return 728 } 729 sOut <- st 730 errChan <- nil 731 }() 732 733 clientConf := testConfig.Clone() 734 clientConf.NextProtos = []string{proto1} 735 clientConf.ClientSessionCache = NewLRUClientSessionCache(10) 736 extraConf := &ExtraConfig{AlternativeRecordLayer: &recordLayer{in: sOut, out: sIn}} 737 client := Client(&unusedConn{remoteAddr: raddr}, clientConf, extraConf) 738 if err := client.Handshake(); err != nil { 739 t.Fatalf("first handshake failed %s", err) 740 } 741 if err := <-errChan; err != nil { 742 t.Fatalf("first handshake failed %s", err) 743 } 744 if err := client.HandlePostHandshakeMessage(); err != nil { 745 t.Fatalf("handling the session ticket failed: %s", err) 746 } 747 client.Close() 748 749 // now dial the second connection 750 errChan = make(chan error, 1) 751 connStateChan := make(chan ConnectionStateWith0RTT, 1) 752 go func() { 753 serverConf := testConfig.Clone() 754 serverConf.NextProtos = []string{proto2} 755 extraConf := &ExtraConfig{ 756 AlternativeRecordLayer: &recordLayer{in: sIn, out: sOut}, 757 Accept0RTT: func([]byte) bool { return true }, 758 MaxEarlyData: 1, 759 } 760 server := Server(&unusedConn{remoteAddr: raddr}, serverConf, extraConf) 761 defer server.Close() 762 errChan <- server.Handshake() 763 connStateChan <- server.ConnectionStateWith0RTT() 764 }() 765 766 clientConf.NextProtos = []string{proto2} 767 extraConf.Enable0RTT = true 768 var rejected bool 769 extraConf.Rejected0RTT = func() { rejected = true } 770 client = Client(&unusedConn{remoteAddr: raddr}, clientConf, extraConf) 771 if err := client.Handshake(); err != nil { 772 t.Fatalf("second handshake failed %s", err) 773 } 774 defer client.Close() 775 if err := <-errChan; err != nil { 776 t.Fatalf("second handshake failed %s", err) 777 } 778 if expectReject { 779 if !rejected { 780 t.Fatal("expected 0-RTT to be rejected") 781 } 782 if client.ConnectionStateWith0RTT().Used0RTT { 783 t.Fatal("expected 0-RTT to be rejected") 784 } 785 if (<-connStateChan).Used0RTT { 786 t.Fatal("expected 0-RTT to be rejected") 787 } 788 } else { 789 if rejected { 790 t.Fatal("didn't expect 0-RTT to be rejected") 791 } 792 if !client.ConnectionStateWith0RTT().Used0RTT { 793 t.Fatal("didn't expect 0-RTT to be rejected") 794 } 795 if !(<-connStateChan).Used0RTT { 796 t.Fatal("didn't expect 0-RTT to be rejected") 797 } 798 } 799 } 800 801 t.Run("with the same alpn", func(t *testing.T) { 802 run(t, "proto1", "proto1", false) 803 }) 804 t.Run("with different alpn", func(t *testing.T) { 805 run(t, "proto1", "proto2", true) 806 }) 807} 808