1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package qtls
6
7import (
8	"bytes"
9	"crypto/rsa"
10	"crypto/x509"
11	"encoding/base64"
12	"encoding/binary"
13	"encoding/pem"
14	"errors"
15	"fmt"
16	"io"
17	"math/big"
18	"net"
19	"os"
20	"os/exec"
21	"path/filepath"
22	"reflect"
23	"strconv"
24	"strings"
25	"testing"
26	"time"
27
28	"github.com/golang/mock/gomock"
29)
30
31// Note: see comment in handshake_test.go for details of how the reference
32// tests work.
33
// opensslInputEvent enumerates possible inputs that can be sent to an `openssl
// s_server` process through its stdin.
type opensslInputEvent int

const (
	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
	// connection.
	opensslRenegotiate opensslInputEvent = iota

	// opensslSendSentinel causes OpenSSL to send the contents of
	// opensslSentinel on the connection.
	opensslSendSentinel

	// opensslKeyUpdate causes OpenSSL to send a key update message to the
	// client and request one back.
	opensslKeyUpdate
)

// opensslSentinel is the application data OpenSSL is asked to send in response
// to opensslSendSentinel; tests compare what the client reads against it.
const opensslSentinel = "SENTINEL\n"

// opensslInput is a channel of events that is used as the stdin of the OpenSSL
// child process. Closing it makes Read report io.EOF.
type opensslInput chan opensslInputEvent

// Read blocks until the next event arrives and writes the matching s_server
// console command into buf ("R\n", "K\n" or opensslSentinel). It returns
// io.EOF once the channel is closed and panics on an unknown event.
func (i opensslInput) Read(buf []byte) (n int, err error) {
	event, ok := <-i
	if !ok {
		return 0, io.EOF
	}

	var command string
	switch event {
	case opensslRenegotiate:
		command = "R\n"
	case opensslKeyUpdate:
		command = "K\n"
	case opensslSendSentinel:
		command = opensslSentinel
	default:
		panic("unknown event")
	}
	return copy(buf, command), nil
}
72
// opensslOutputSink is an io.Writer that receives the stdout and stderr from an
// `openssl` process and sends a value to handshakeComplete or readKeyUpdate
// when certain messages are seen.
//
// Both channels are unbuffered, so Write blocks until each signal it raises is
// received; whoever runs s_server with "-state" must drain them.
type opensslOutputSink struct {
	handshakeComplete chan struct{}
	readKeyUpdate     chan struct{}
	all               []byte // every byte written so far
	line              []byte // trailing, not yet newline-terminated output
}

// newOpensslOutputSink returns a sink with unbuffered signal channels.
func newOpensslOutputSink() *opensslOutputSink {
	return &opensslOutputSink{
		handshakeComplete: make(chan struct{}),
		readKeyUpdate:     make(chan struct{}),
	}
}

// opensslEndOfHandshake is a message that the “openssl s_server” tool will
// print when a handshake completes if run with “-state”.
const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

// opensslReadKeyUpdate is a message that the “openssl s_server” tool will
// print when a KeyUpdate message is received if run with “-state”.
const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"

// Write records data and, for every complete line now available, signals the
// matching channel when the line is exactly one of the messages above. An
// incomplete trailing line is kept until a later Write finishes it.
func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
	o.line = append(o.line, data...)
	o.all = append(o.all, data...)

	for {
		i := bytes.IndexByte(o.line, '\n')
		if i < 0 {
			break
		}

		// A string switch over a []byte conversion does not allocate.
		switch string(o.line[:i]) {
		case opensslEndOfHandshake:
			o.handshakeComplete <- struct{}{}
		case opensslReadKeyUpdate:
			o.readKeyUpdate <- struct{}{}
		}
		o.line = o.line[i+1:]
	}

	return len(data), nil
}

// String returns everything written to the sink so far.
func (o *opensslOutputSink) String() string {
	return string(o.all)
}
120
// clientTest represents a test of the TLS client handshake against a reference
// implementation.
//
// When recording (write mode in run), the client talks to a live OpenSSL
// s_server configured from these fields and the traffic is saved under
// testdata; otherwise the saved traffic is replayed to the client.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored (testdata/Client-<name>; the
	// per-version runners prefix it with the protocol version).
	name string
	// args, if not empty, contains a series of arguments for the
	// command to run for the reference server.
	args []string
	// config, if not nil, contains a custom Config to use for this test.
	// When nil, testConfig is used.
	config *Config
	// extraConfig contains a custom ExtraConfig to use for this test.
	extraConfig *ExtraConfig
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server. When empty, testRSACertificate is used.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey, ed25519.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// When nil, testRSAPrivateKey is used.
	key interface{}
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection. It returns a non-nil
	// error if the ConnectionState is unacceptable.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated. A non-zero value requires "-state" in args.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the number of the
	// renegotiation attempt that is expected to fail (counting from 1).
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with any error
	// arising from renegotiation. It can map expected errors to nil to
	// ignore them.
	checkRenegotiationError func(renegotiationNum int, err error) error
	// sendKeyUpdate will cause the server to send a KeyUpdate message.
	// It requires "-state" in args.
	sendKeyUpdate bool
}
161
162var serverCommand = []string{"openssl", "s_server", "-no_ticket", "-num_tickets", "0"}
163
164// connFromCommand starts the reference server process, connects to it and
165// returns a recordingConn for the connection. The stdin return value is an
166// opensslInput for the stdin of the child process. It must be closed before
167// Waiting for child.
168func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
169	cert := testRSACertificate
170	if len(test.cert) > 0 {
171		cert = test.cert
172	}
173	certPath := tempFile(string(cert))
174	defer os.Remove(certPath)
175
176	var key interface{} = testRSAPrivateKey
177	if test.key != nil {
178		key = test.key
179	}
180	derBytes, err := x509.MarshalPKCS8PrivateKey(key)
181	if err != nil {
182		panic(err)
183	}
184
185	var pemOut bytes.Buffer
186	pem.Encode(&pemOut, &pem.Block{Type: "PRIVATE KEY", Bytes: derBytes})
187
188	keyPath := tempFile(pemOut.String())
189	defer os.Remove(keyPath)
190
191	var command []string
192	command = append(command, serverCommand...)
193	command = append(command, test.args...)
194	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
195	// serverPort contains the port that OpenSSL will listen on. OpenSSL
196	// can't take "0" as an argument here so we have to pick a number and
197	// hope that it's not in use on the machine. Since this only occurs
198	// when -update is given and thus when there's a human watching the
199	// test, this isn't too bad.
200	const serverPort = 24323
201	command = append(command, "-accept", strconv.Itoa(serverPort))
202
203	if len(test.extensions) > 0 {
204		var serverInfo bytes.Buffer
205		for _, ext := range test.extensions {
206			pem.Encode(&serverInfo, &pem.Block{
207				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
208				Bytes: ext,
209			})
210		}
211		serverInfoPath := tempFile(serverInfo.String())
212		defer os.Remove(serverInfoPath)
213		command = append(command, "-serverinfo", serverInfoPath)
214	}
215
216	if test.numRenegotiations > 0 || test.sendKeyUpdate {
217		found := false
218		for _, flag := range command[1:] {
219			if flag == "-state" {
220				found = true
221				break
222			}
223		}
224
225		if !found {
226			panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
227		}
228	}
229
230	cmd := exec.Command(command[0], command[1:]...)
231	stdin = opensslInput(make(chan opensslInputEvent))
232	cmd.Stdin = stdin
233	out := newOpensslOutputSink()
234	cmd.Stdout = out
235	cmd.Stderr = out
236	if err := cmd.Start(); err != nil {
237		return nil, nil, nil, nil, err
238	}
239
240	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
241	// opening the listening socket, so we can't use that to wait until it
242	// has started listening. Thus we are forced to poll until we get a
243	// connection.
244	var tcpConn net.Conn
245	for i := uint(0); i < 5; i++ {
246		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
247			IP:   net.IPv4(127, 0, 0, 1),
248			Port: serverPort,
249		})
250		if err == nil {
251			break
252		}
253		time.Sleep((1 << i) * 5 * time.Millisecond)
254	}
255	if err != nil {
256		close(stdin)
257		cmd.Process.Kill()
258		err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
259		return nil, nil, nil, nil, err
260	}
261
262	record := &recordingConn{
263		Conn: tcpConn,
264	}
265
266	return record, cmd, stdin, out, nil
267}
268
269func (test *clientTest) dataPath() string {
270	return filepath.Join("testdata", "Client-"+test.name)
271}
272
273func (test *clientTest) loadData() (flows [][]byte, err error) {
274	in, err := os.Open(test.dataPath())
275	if err != nil {
276		return nil, err
277	}
278	defer in.Close()
279	return parseTestData(in)
280}
281
282func (test *clientTest) run(t *testing.T, write bool) {
283	var clientConn, serverConn net.Conn
284	var recordingConn *recordingConn
285	var childProcess *exec.Cmd
286	var stdin opensslInput
287	var stdout *opensslOutputSink
288
289	if write {
290		var err error
291		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
292		if err != nil {
293			t.Fatalf("Failed to start subcommand: %s", err)
294		}
295		clientConn = recordingConn
296		defer func() {
297			if t.Failed() {
298				t.Logf("OpenSSL output:\n\n%s", stdout.all)
299			}
300		}()
301	} else {
302		clientConn, serverConn = localPipe(t)
303	}
304
305	doneChan := make(chan bool)
306	defer func() {
307		clientConn.Close()
308		<-doneChan
309	}()
310	go func() {
311		defer close(doneChan)
312
313		config := test.config
314		if config == nil {
315			config = testConfig
316		}
317		client := Client(clientConn, config, test.extraConfig)
318		defer client.Close()
319
320		if _, err := client.Write([]byte("hello\n")); err != nil {
321			t.Errorf("Client.Write failed: %s", err)
322			return
323		}
324
325		for i := 1; i <= test.numRenegotiations; i++ {
326			// The initial handshake will generate a
327			// handshakeComplete signal which needs to be quashed.
328			if i == 1 && write {
329				<-stdout.handshakeComplete
330			}
331
332			// OpenSSL will try to interleave application data and
333			// a renegotiation if we send both concurrently.
334			// Therefore: ask OpensSSL to start a renegotiation, run
335			// a goroutine to call client.Read and thus process the
336			// renegotiation request, watch for OpenSSL's stdout to
337			// indicate that the handshake is complete and,
338			// finally, have OpenSSL write something to cause
339			// client.Read to complete.
340			if write {
341				stdin <- opensslRenegotiate
342			}
343
344			signalChan := make(chan struct{})
345
346			go func() {
347				defer close(signalChan)
348
349				buf := make([]byte, 256)
350				n, err := client.Read(buf)
351
352				if test.checkRenegotiationError != nil {
353					newErr := test.checkRenegotiationError(i, err)
354					if err != nil && newErr == nil {
355						return
356					}
357					err = newErr
358				}
359
360				if err != nil {
361					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
362					return
363				}
364
365				buf = buf[:n]
366				if !bytes.Equal([]byte(opensslSentinel), buf) {
367					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
368				}
369
370				if expected := i + 1; client.handshakes != expected {
371					t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
372				}
373			}()
374
375			if write && test.renegotiationExpectedToFail != i {
376				<-stdout.handshakeComplete
377				stdin <- opensslSendSentinel
378			}
379			<-signalChan
380		}
381
382		if test.sendKeyUpdate {
383			if write {
384				<-stdout.handshakeComplete
385				stdin <- opensslKeyUpdate
386			}
387
388			doneRead := make(chan struct{})
389
390			go func() {
391				defer close(doneRead)
392
393				buf := make([]byte, 256)
394				n, err := client.Read(buf)
395
396				if err != nil {
397					t.Errorf("Client.Read failed after KeyUpdate: %s", err)
398					return
399				}
400
401				buf = buf[:n]
402				if !bytes.Equal([]byte(opensslSentinel), buf) {
403					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
404				}
405			}()
406
407			if write {
408				// There's no real reason to wait for the client KeyUpdate to
409				// send data with the new server keys, except that s_server
410				// drops writes if they are sent at the wrong time.
411				<-stdout.readKeyUpdate
412				stdin <- opensslSendSentinel
413			}
414			<-doneRead
415
416			if _, err := client.Write([]byte("hello again\n")); err != nil {
417				t.Errorf("Client.Write failed: %s", err)
418				return
419			}
420		}
421
422		if test.validate != nil {
423			if err := test.validate(client.ConnectionState()); err != nil {
424				t.Errorf("validate callback returned error: %s", err)
425			}
426		}
427
428		// If the server sent us an alert after our last flight, give it a
429		// chance to arrive.
430		if write && test.renegotiationExpectedToFail == 0 {
431			if err := peekError(client); err != nil {
432				t.Errorf("final Read returned an error: %s", err)
433			}
434		}
435	}()
436
437	if !write {
438		flows, err := test.loadData()
439		if err != nil {
440			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
441		}
442		for i, b := range flows {
443			if i%2 == 1 {
444				if *fast {
445					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
446				} else {
447					serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
448				}
449				serverConn.Write(b)
450				continue
451			}
452			bb := make([]byte, len(b))
453			if *fast {
454				serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
455			} else {
456				serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
457			}
458			_, err := io.ReadFull(serverConn, bb)
459			if err != nil {
460				t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
461			}
462			if !bytes.Equal(b, bb) {
463				t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
464			}
465		}
466	}
467
468	<-doneChan
469	if !write {
470		serverConn.Close()
471	}
472
473	if write {
474		path := test.dataPath()
475		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
476		if err != nil {
477			t.Fatalf("Failed to create output file: %s", err)
478		}
479		defer out.Close()
480		recordingConn.Close()
481		close(stdin)
482		childProcess.Process.Kill()
483		childProcess.Wait()
484		if len(recordingConn.flows) < 3 {
485			t.Fatalf("Client connection didn't work")
486		}
487		recordingConn.WriteTo(out)
488		t.Logf("Wrote %s\n", path)
489	}
490}
491
// peekError does a read with a short timeout to check if the next read would
// cause an error, for example if there is an alert waiting on the wire.
//
// It returns nil if the read simply timed out. It returns an error if any data
// arrived; that byte is consumed. Otherwise it returns whatever non-timeout
// error Read produced. The read deadline is cleared before returning, so later
// reads on conn are not cut short by the expired peek deadline.
func peekError(conn net.Conn) error {
	conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
	defer conn.SetReadDeadline(time.Time{})

	n, err := conn.Read(make([]byte, 1))
	if n != 0 {
		return errors.New("unexpectedly read data")
	}
	var netErr net.Error
	if err != nil && !(errors.As(err, &netErr) && netErr.Timeout()) {
		return err
	}
	return nil
}
505
506func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
507	// Make a deep copy of the template before going parallel.
508	test := *template
509	if template.config != nil {
510		test.config = template.config.Clone()
511	}
512	test.name = version + "-" + test.name
513	test.args = append([]string{option}, test.args...)
514
515	runTestAndUpdateIfNeeded(t, version, test.run, false)
516}
517
518func runClientTestTLS10(t *testing.T, template *clientTest) {
519	runClientTestForVersion(t, template, "TLSv10", "-tls1")
520}
521
522func runClientTestTLS11(t *testing.T, template *clientTest) {
523	runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
524}
525
526func runClientTestTLS12(t *testing.T, template *clientTest) {
527	runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
528}
529
530func runClientTestTLS13(t *testing.T, template *clientTest) {
531	runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
532}
533
534func TestHandshakeClientRSARC4(t *testing.T) {
535	test := &clientTest{
536		name: "RSA-RC4",
537		args: []string{"-cipher", "RC4-SHA"},
538	}
539	runClientTestTLS10(t, test)
540	runClientTestTLS11(t, test)
541	runClientTestTLS12(t, test)
542}
543
544func TestHandshakeClientRSAAES128GCM(t *testing.T) {
545	test := &clientTest{
546		name: "AES128-GCM-SHA256",
547		args: []string{"-cipher", "AES128-GCM-SHA256"},
548	}
549	runClientTestTLS12(t, test)
550}
551
552func TestHandshakeClientRSAAES256GCM(t *testing.T) {
553	test := &clientTest{
554		name: "AES256-GCM-SHA384",
555		args: []string{"-cipher", "AES256-GCM-SHA384"},
556	}
557	runClientTestTLS12(t, test)
558}
559
560func TestHandshakeClientECDHERSAAES(t *testing.T) {
561	test := &clientTest{
562		name: "ECDHE-RSA-AES",
563		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
564	}
565	runClientTestTLS10(t, test)
566	runClientTestTLS11(t, test)
567	runClientTestTLS12(t, test)
568}
569
570func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
571	test := &clientTest{
572		name: "ECDHE-ECDSA-AES",
573		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
574		cert: testECDSACertificate,
575		key:  testECDSAPrivateKey,
576	}
577	runClientTestTLS10(t, test)
578	runClientTestTLS11(t, test)
579	runClientTestTLS12(t, test)
580}
581
582func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
583	test := &clientTest{
584		name: "ECDHE-ECDSA-AES-GCM",
585		args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
586		cert: testECDSACertificate,
587		key:  testECDSAPrivateKey,
588	}
589	runClientTestTLS12(t, test)
590}
591
592func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
593	test := &clientTest{
594		name: "ECDHE-ECDSA-AES256-GCM-SHA384",
595		args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
596		cert: testECDSACertificate,
597		key:  testECDSAPrivateKey,
598	}
599	runClientTestTLS12(t, test)
600}
601
602func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
603	test := &clientTest{
604		name: "AES128-SHA256",
605		args: []string{"-cipher", "AES128-SHA256"},
606	}
607	runClientTestTLS12(t, test)
608}
609
610func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
611	test := &clientTest{
612		name: "ECDHE-RSA-AES128-SHA256",
613		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
614	}
615	runClientTestTLS12(t, test)
616}
617
618func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
619	test := &clientTest{
620		name: "ECDHE-ECDSA-AES128-SHA256",
621		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
622		cert: testECDSACertificate,
623		key:  testECDSAPrivateKey,
624	}
625	runClientTestTLS12(t, test)
626}
627
628func TestHandshakeClientX25519(t *testing.T) {
629	config := testConfig.Clone()
630	config.CurvePreferences = []CurveID{X25519}
631
632	test := &clientTest{
633		name:   "X25519-ECDHE",
634		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
635		config: config,
636	}
637
638	runClientTestTLS12(t, test)
639	runClientTestTLS13(t, test)
640}
641
642func TestHandshakeClientP256(t *testing.T) {
643	config := testConfig.Clone()
644	config.CurvePreferences = []CurveID{CurveP256}
645
646	test := &clientTest{
647		name:   "P256-ECDHE",
648		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
649		config: config,
650	}
651
652	runClientTestTLS12(t, test)
653	runClientTestTLS13(t, test)
654}
655
656func TestHandshakeClientHelloRetryRequest(t *testing.T) {
657	config := testConfig.Clone()
658	config.CurvePreferences = []CurveID{X25519, CurveP256}
659
660	test := &clientTest{
661		name:        "HelloRetryRequest",
662		args:        []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
663		config:      config,
664		extraConfig: &ExtraConfig{Rejected0RTT: func() { t.Error("didn't expect 0-RTT rejection") }},
665	}
666
667	runClientTestTLS13(t, test)
668}
669
670func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
671	config := testConfig.Clone()
672	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
673
674	test := &clientTest{
675		name:   "ECDHE-RSA-CHACHA20-POLY1305",
676		args:   []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
677		config: config,
678	}
679
680	runClientTestTLS12(t, test)
681}
682
683func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
684	config := testConfig.Clone()
685	config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
686
687	test := &clientTest{
688		name:   "ECDHE-ECDSA-CHACHA20-POLY1305",
689		args:   []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
690		config: config,
691		cert:   testECDSACertificate,
692		key:    testECDSAPrivateKey,
693	}
694
695	runClientTestTLS12(t, test)
696}
697
698func TestHandshakeClientAES128SHA256(t *testing.T) {
699	test := &clientTest{
700		name: "AES128-SHA256",
701		args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
702	}
703	runClientTestTLS13(t, test)
704}
705func TestHandshakeClientAES256SHA384(t *testing.T) {
706	test := &clientTest{
707		name: "AES256-SHA384",
708		args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
709	}
710	runClientTestTLS13(t, test)
711}
712func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
713	test := &clientTest{
714		name: "CHACHA20-SHA256",
715		args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
716	}
717	runClientTestTLS13(t, test)
718}
719
720func TestHandshakeClientECDSATLS13(t *testing.T) {
721	test := &clientTest{
722		name: "ECDSA",
723		cert: testECDSACertificate,
724		key:  testECDSAPrivateKey,
725	}
726	runClientTestTLS13(t, test)
727}
728
729func TestHandshakeClientEd25519(t *testing.T) {
730	test := &clientTest{
731		name: "Ed25519",
732		cert: testEd25519Certificate,
733		key:  testEd25519PrivateKey,
734	}
735	runClientTestTLS12(t, test)
736	runClientTestTLS13(t, test)
737
738	config := testConfig.Clone()
739	cert, _ := X509KeyPair([]byte(clientEd25519CertificatePEM), []byte(clientEd25519KeyPEM))
740	config.Certificates = []Certificate{cert}
741
742	test = &clientTest{
743		name:   "ClientCert-Ed25519",
744		args:   []string{"-Verify", "1"},
745		config: config,
746	}
747
748	runClientTestTLS12(t, test)
749	runClientTestTLS13(t, test)
750}
751
752func TestHandshakeClientCertRSA(t *testing.T) {
753	config := testConfig.Clone()
754	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
755	config.Certificates = []Certificate{cert}
756
757	test := &clientTest{
758		name:   "ClientCert-RSA-RSA",
759		args:   []string{"-cipher", "AES128", "-Verify", "1"},
760		config: config,
761	}
762
763	runClientTestTLS10(t, test)
764	runClientTestTLS12(t, test)
765
766	test = &clientTest{
767		name:   "ClientCert-RSA-ECDSA",
768		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
769		config: config,
770		cert:   testECDSACertificate,
771		key:    testECDSAPrivateKey,
772	}
773
774	runClientTestTLS10(t, test)
775	runClientTestTLS12(t, test)
776	runClientTestTLS13(t, test)
777
778	test = &clientTest{
779		name:   "ClientCert-RSA-AES256-GCM-SHA384",
780		args:   []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
781		config: config,
782		cert:   testRSACertificate,
783		key:    testRSAPrivateKey,
784	}
785
786	runClientTestTLS12(t, test)
787}
788
789func TestHandshakeClientCertECDSA(t *testing.T) {
790	config := testConfig.Clone()
791	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
792	config.Certificates = []Certificate{cert}
793
794	test := &clientTest{
795		name:   "ClientCert-ECDSA-RSA",
796		args:   []string{"-cipher", "AES128", "-Verify", "1"},
797		config: config,
798	}
799
800	runClientTestTLS10(t, test)
801	runClientTestTLS12(t, test)
802	runClientTestTLS13(t, test)
803
804	test = &clientTest{
805		name:   "ClientCert-ECDSA-ECDSA",
806		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
807		config: config,
808		cert:   testECDSACertificate,
809		key:    testECDSAPrivateKey,
810	}
811
812	runClientTestTLS10(t, test)
813	runClientTestTLS12(t, test)
814}
815
816// TestHandshakeClientCertRSAPSS tests rsa_pss_rsae_sha256 signatures from both
817// client and server certificates. It also serves from both sides a certificate
818// signed itself with RSA-PSS, mostly to check that crypto/x509 chain validation
819// works.
820func TestHandshakeClientCertRSAPSS(t *testing.T) {
821	cert, err := x509.ParseCertificate(testRSAPSSCertificate)
822	if err != nil {
823		panic(err)
824	}
825	rootCAs := x509.NewCertPool()
826	rootCAs.AddCert(cert)
827
828	config := testConfig.Clone()
829	// Use GetClientCertificate to bypass the client certificate selection logic.
830	config.GetClientCertificate = func(*CertificateRequestInfo) (*Certificate, error) {
831		return &Certificate{
832			Certificate: [][]byte{testRSAPSSCertificate},
833			PrivateKey:  testRSAPrivateKey,
834		}, nil
835	}
836	config.RootCAs = rootCAs
837
838	test := &clientTest{
839		name: "ClientCert-RSA-RSAPSS",
840		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
841			"rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
842		config: config,
843		cert:   testRSAPSSCertificate,
844		key:    testRSAPrivateKey,
845	}
846	runClientTestTLS12(t, test)
847	runClientTestTLS13(t, test)
848}
849
850func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
851	config := testConfig.Clone()
852	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
853	config.Certificates = []Certificate{cert}
854
855	test := &clientTest{
856		name: "ClientCert-RSA-RSAPKCS1v15",
857		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
858			"rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
859		config: config,
860	}
861
862	runClientTestTLS12(t, test)
863}
864
865func TestClientKeyUpdate(t *testing.T) {
866	test := &clientTest{
867		name:          "KeyUpdate",
868		args:          []string{"-state"},
869		sendKeyUpdate: true,
870	}
871	runClientTestTLS13(t, test)
872}
873
874func TestResumption(t *testing.T) {
875	t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12, false) })
876	t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13, false) })
877	t.Run("TLSv13, saving app data", func(t *testing.T) { testResumption(t, VersionTLS13, true) })
878	t.Run("TLSv13, with 0-RTT", func(t *testing.T) { testResumption0RTT(t, false) })
879	t.Run("TLSv13, with 0-RTT, saving app data", func(t *testing.T) { testResumption0RTT(t, true) })
880}
881
882func testResumption(t *testing.T, version uint16, saveAppData bool) {
883	if testing.Short() {
884		t.Skip("skipping in -short mode")
885	}
886	serverConfig := &Config{
887		MaxVersion:   version,
888		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
889		Certificates: testConfig.Certificates,
890	}
891
892	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
893	if err != nil {
894		panic(err)
895	}
896
897	rootCAs := x509.NewCertPool()
898	rootCAs.AddCert(issuer)
899
900	var restoredAppData []byte
901	clientConfig := &Config{
902		MaxVersion:         version,
903		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
904		ClientSessionCache: NewLRUClientSessionCache(32),
905		RootCAs:            rootCAs,
906		ServerName:         "example.golang",
907	}
908	clientExtraConfig := &ExtraConfig{}
909	if saveAppData {
910		clientExtraConfig.GetAppDataForSessionState = func() []byte { return []byte("foobar") }
911		clientExtraConfig.SetAppDataFromSessionState = func(data []byte) { restoredAppData = data }
912	}
913
914	testResumeState := func(test string, didResume bool) {
915		_, hs, err := testHandshakeWithExtraConfig(t, clientConfig, clientExtraConfig, serverConfig, nil)
916		if err != nil {
917			t.Fatalf("%s: handshake failed: %s", test, err)
918		}
919		if hs.DidResume != didResume {
920			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
921		}
922		if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
923			t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
924		}
925		if got, want := hs.ServerName, clientConfig.ServerName; got != want {
926			t.Errorf("%s: server name %s, want %s", test, got, want)
927		}
928		if didResume && saveAppData {
929			if !bytes.Equal(restoredAppData, []byte("foobar")) {
930				t.Fatalf("Expected to restore app data saved with the session state. Got: %#v", restoredAppData)
931			}
932			restoredAppData = nil
933		}
934	}
935
936	getTicket := func() []byte {
937		return fromClientSessionState(clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state).sessionTicket
938	}
939	deleteTicket := func() {
940		ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
941		clientConfig.ClientSessionCache.Put(ticketKey, nil)
942	}
943	corruptTicket := func() {
944		fromClientSessionState(clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state).masterSecret[0] ^= 0xff
945	}
946	randomKey := func() [32]byte {
947		var k [32]byte
948		if _, err := io.ReadFull(fromConfig(serverConfig).rand(), k[:]); err != nil {
949			t.Fatalf("Failed to read new SessionTicketKey: %s", err)
950		}
951		return k
952	}
953
954	testResumeState("Handshake", false)
955	ticket := getTicket()
956	testResumeState("Resume", true)
957	if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 {
958		t.Fatal("first ticket doesn't match ticket after resumption")
959	}
960	if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 {
961		t.Fatal("ticket didn't change after resumption")
962	}
963
964	// An old session ticket can resume, but the server will provide a ticket encrypted with a fresh key.
965	serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
966	testResumeState("ResumeWithOldTicket", true)
967	if bytes.Equal(ticket[:ticketKeyNameLen], getTicket()[:ticketKeyNameLen]) {
968		t.Fatal("old first ticket matches the fresh one")
969	}
970
971	// Now the session tickey key is expired, so a full handshake should occur.
972	serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
973	testResumeState("ResumeWithExpiredTicket", false)
974	if bytes.Equal(ticket, getTicket()) {
975		t.Fatal("expired first ticket matches the fresh one")
976	}
977
978	serverConfig.Time = func() time.Time { return time.Now() } // reset the time back
979	key1 := randomKey()
980	serverConfig.SetSessionTicketKeys([][32]byte{key1})
981
982	testResumeState("InvalidSessionTicketKey", false)
983	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
984
985	key2 := randomKey()
986	serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
987	ticket = getTicket()
988	testResumeState("KeyChange", true)
989	if bytes.Equal(ticket, getTicket()) {
990		t.Fatal("new ticket wasn't included while resuming")
991	}
992	testResumeState("KeyChangeFinish", true)
993
994	// Age the session ticket a bit, but not yet expired.
995	serverConfig.Time = func() time.Time { return time.Now().Add(24*time.Hour + time.Minute) }
996	testResumeState("OldSessionTicket", true)
997	ticket = getTicket()
998	// Expire the session ticket, which would force a full handshake.
999	serverConfig.Time = func() time.Time { return time.Now().Add(24*8*time.Hour + time.Minute) }
1000	testResumeState("ExpiredSessionTicket", false)
1001	if bytes.Equal(ticket, getTicket()) {
1002		t.Fatal("new ticket wasn't provided after old ticket expired")
1003	}
1004
1005	// Age the session ticket a bit at a time, but don't expire it.
1006	d := 0 * time.Hour
1007	for i := 0; i < 13; i++ {
1008		d += 12 * time.Hour
1009		serverConfig.Time = func() time.Time { return time.Now().Add(d) }
1010		testResumeState("OldSessionTicket", true)
1011	}
1012	// Expire it (now a little more than 7 days) and make sure a full
1013	// handshake occurs for TLS 1.2. Resumption should still occur for
1014	// TLS 1.3 since the client should be using a fresh ticket sent over
1015	// by the server.
1016	d += 12 * time.Hour
1017	serverConfig.Time = func() time.Time { return time.Now().Add(d) }
1018	if version == VersionTLS13 {
1019		testResumeState("ExpiredSessionTicket", true)
1020	} else {
1021		testResumeState("ExpiredSessionTicket", false)
1022	}
1023	if bytes.Equal(ticket, getTicket()) {
1024		t.Fatal("new ticket wasn't provided after old ticket expired")
1025	}
1026
1027	// Reset serverConfig to ensure that calling SetSessionTicketKeys
1028	// before the serverConfig is used works.
1029	serverConfig = &Config{
1030		MaxVersion:   version,
1031		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
1032		Certificates: testConfig.Certificates,
1033	}
1034	serverConfig.SetSessionTicketKeys([][32]byte{key2})
1035
1036	testResumeState("FreshConfig", true)
1037
1038	// In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
1039	// hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
1040	if version != VersionTLS13 {
1041		clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
1042		testResumeState("DifferentCipherSuite", false)
1043		testResumeState("DifferentCipherSuiteRecovers", true)
1044	}
1045
1046	deleteTicket()
1047	testResumeState("WithoutSessionTicket", false)
1048
1049	// Session resumption should work when using client certificates
1050	deleteTicket()
1051	serverConfig.ClientCAs = rootCAs
1052	serverConfig.ClientAuth = RequireAndVerifyClientCert
1053	clientConfig.Certificates = serverConfig.Certificates
1054	testResumeState("InitialHandshake", false)
1055	testResumeState("WithClientCertificates", true)
1056	serverConfig.ClientAuth = NoClientCert
1057
1058	// Tickets should be removed from the session cache on TLS handshake
1059	// failure, and the client should recover from a corrupted PSK
1060	testResumeState("FetchTicketToCorrupt", false)
1061	corruptTicket()
1062	_, _, err = testHandshake(t, clientConfig, serverConfig)
1063	if err == nil {
1064		t.Fatalf("handshake did not fail with a corrupted client secret")
1065	}
1066	testResumeState("AfterHandshakeFailure", false)
1067
1068	clientConfig.ClientSessionCache = nil
1069	testResumeState("WithoutSessionCache", false)
1070}
1071
1072func testResumption0RTT(t *testing.T, saveAppData bool) {
1073	mockCtrl := gomock.NewController(t)
1074	defer mockCtrl.Finish()
1075
1076	serverConfig := testConfig.Clone()
1077	serverExtraConfig := &ExtraConfig{
1078		MaxEarlyData: 100,
1079		Accept0RTT:   func([]byte) bool { return true },
1080	}
1081
1082	cache := NewMockClientSessionCache(mockCtrl)
1083	clientConfig := testConfig.Clone()
1084	clientConfig.ClientSessionCache = cache
1085	clientExtraConfig := &ExtraConfig{Enable0RTT: true}
1086	var restoredAppData []byte
1087	if saveAppData {
1088		clientExtraConfig.GetAppDataForSessionState = func() []byte { return []byte("foobar") }
1089		clientExtraConfig.SetAppDataFromSessionState = func(data []byte) { restoredAppData = data }
1090	}
1091
1092	// check that the ticket is deleted when 0-RTT is used
1093	var state *ClientSessionState
1094	gomock.InOrder(
1095		cache.EXPECT().Get(gomock.Any()),
1096		cache.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, s *ClientSessionState) {
1097			state = s
1098		}),
1099	)
1100	_, _, err := testHandshakeWithExtraConfig(t, clientConfig, clientExtraConfig, serverConfig, serverExtraConfig)
1101	if err != nil {
1102		t.Fatalf("first handshake failed: %s", err)
1103	}
1104
1105	gomock.InOrder(
1106		cache.EXPECT().Get(gomock.Any()).Return(state, true),
1107		cache.EXPECT().Put(gomock.Any(), nil), // expect the ticket to be deleted immediately
1108		cache.EXPECT().Put(gomock.Any(), gomock.Any()),
1109	)
1110	_, hs, err := testHandshakeWithExtraConfig(t, clientConfig, clientExtraConfig, serverConfig, serverExtraConfig)
1111	if err != nil {
1112		t.Fatalf("second handshake failed: %s", err)
1113	}
1114	if !hs.Used0RTT {
1115		t.Fatal("should have used 0-RTT during the second handshake")
1116	}
1117
1118	// check that the ticket is not deleted when 0-RTT is not used
1119	clientExtraConfig.Enable0RTT = false
1120	gomock.InOrder(
1121		cache.EXPECT().Get(gomock.Any()),
1122		cache.EXPECT().Put(gomock.Any(), gomock.Any()).Do(func(_ string, s *ClientSessionState) {
1123			state = s
1124		}),
1125	)
1126	_, _, err = testHandshakeWithExtraConfig(t, clientConfig, clientExtraConfig, serverConfig, serverExtraConfig)
1127	if err != nil {
1128		t.Fatalf("first handshake failed: %s", err)
1129	}
1130
1131	gomock.InOrder(
1132		cache.EXPECT().Get(gomock.Any()).Return(state, true),
1133		cache.EXPECT().Put(gomock.Any(), gomock.Any()),
1134	)
1135	_, hs, err = testHandshakeWithExtraConfig(t, clientConfig, clientExtraConfig, serverConfig, serverExtraConfig)
1136	if err != nil {
1137		t.Fatalf("second handshake failed: %s", err)
1138	}
1139	if hs.Used0RTT {
1140		t.Fatal("should not have used 0-RTT during the second handshake")
1141	}
1142	if saveAppData && !bytes.Equal(restoredAppData, []byte("foobar")) {
1143		t.Fatalf("expected app data to be restored. Got: %#v", restoredAppData)
1144	}
1145}
1146
1147func TestLRUClientSessionCache(t *testing.T) {
1148	// Initialize cache of capacity 4.
1149	cache := NewLRUClientSessionCache(4)
1150	cs := make([]ClientSessionState, 6)
1151	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
1152
1153	// Add 4 entries to the cache and look them up.
1154	for i := 0; i < 4; i++ {
1155		cache.Put(keys[i], &cs[i])
1156	}
1157	for i := 0; i < 4; i++ {
1158		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1159			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
1160		}
1161	}
1162
1163	// Add 2 more entries to the cache. First 2 should be evicted.
1164	for i := 4; i < 6; i++ {
1165		cache.Put(keys[i], &cs[i])
1166	}
1167	for i := 0; i < 2; i++ {
1168		if s, ok := cache.Get(keys[i]); ok || s != nil {
1169			t.Fatalf("session cache should have evicted key: %s", keys[i])
1170		}
1171	}
1172
1173	// Touch entry 2. LRU should evict 3 next.
1174	cache.Get(keys[2])
1175	cache.Put(keys[0], &cs[0])
1176	if s, ok := cache.Get(keys[3]); ok || s != nil {
1177		t.Fatalf("session cache should have evicted key 3")
1178	}
1179
1180	// Update entry 0 in place.
1181	cache.Put(keys[0], &cs[3])
1182	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
1183		t.Fatalf("session cache failed update for key 0")
1184	}
1185
1186	// Calling Put with a nil entry deletes the key.
1187	cache.Put(keys[0], nil)
1188	if _, ok := cache.Get(keys[0]); ok {
1189		t.Fatalf("session cache failed to delete key 0")
1190	}
1191
1192	// Delete entry 2. LRU should keep 4 and 5
1193	cache.Put(keys[2], nil)
1194	if _, ok := cache.Get(keys[2]); ok {
1195		t.Fatalf("session cache failed to delete key 4")
1196	}
1197	for i := 4; i < 6; i++ {
1198		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
1199			t.Fatalf("session cache should not have deleted key: %s", keys[i])
1200		}
1201	}
1202}
1203
1204func TestKeyLogTLS12(t *testing.T) {
1205	var serverBuf, clientBuf bytes.Buffer
1206
1207	clientConfig := testConfig.Clone()
1208	clientConfig.KeyLogWriter = &clientBuf
1209	clientConfig.MaxVersion = VersionTLS12
1210
1211	serverConfig := testConfig.Clone()
1212	serverConfig.KeyLogWriter = &serverBuf
1213	serverConfig.MaxVersion = VersionTLS12
1214
1215	c, s := localPipe(t)
1216	done := make(chan bool)
1217
1218	go func() {
1219		defer close(done)
1220
1221		if err := Server(s, serverConfig, nil).Handshake(); err != nil {
1222			t.Errorf("server: %s", err)
1223			return
1224		}
1225		s.Close()
1226	}()
1227
1228	if err := Client(c, clientConfig, nil).Handshake(); err != nil {
1229		t.Fatalf("client: %s", err)
1230	}
1231
1232	c.Close()
1233	<-done
1234
1235	checkKeylogLine := func(side, loggedLine string) {
1236		if len(loggedLine) == 0 {
1237			t.Fatalf("%s: no keylog line was produced", side)
1238		}
1239		const expectedLen = 13 /* "CLIENT_RANDOM" */ +
1240			1 /* space */ +
1241			32*2 /* hex client nonce */ +
1242			1 /* space */ +
1243			48*2 /* hex master secret */ +
1244			1 /* new line */
1245		if len(loggedLine) != expectedLen {
1246			t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
1247		}
1248		if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
1249			t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
1250		}
1251	}
1252
1253	checkKeylogLine("client", clientBuf.String())
1254	checkKeylogLine("server", serverBuf.String())
1255}
1256
1257func TestKeyLogTLS13(t *testing.T) {
1258	var serverBuf, clientBuf bytes.Buffer
1259
1260	clientConfig := testConfig.Clone()
1261	clientConfig.KeyLogWriter = &clientBuf
1262
1263	serverConfig := testConfig.Clone()
1264	serverConfig.KeyLogWriter = &serverBuf
1265
1266	c, s := localPipe(t)
1267	done := make(chan bool)
1268
1269	go func() {
1270		defer close(done)
1271
1272		if err := Server(s, serverConfig, nil).Handshake(); err != nil {
1273			t.Errorf("server: %s", err)
1274			return
1275		}
1276		s.Close()
1277	}()
1278
1279	if err := Client(c, clientConfig, nil).Handshake(); err != nil {
1280		t.Fatalf("client: %s", err)
1281	}
1282
1283	c.Close()
1284	<-done
1285
1286	checkKeylogLines := func(side, loggedLines string) {
1287		loggedLines = strings.TrimSpace(loggedLines)
1288		lines := strings.Split(loggedLines, "\n")
1289		if len(lines) != 4 {
1290			t.Errorf("Expected the %s to log 4 lines, got %d", side, len(lines))
1291		}
1292	}
1293
1294	checkKeylogLines("client", clientBuf.String())
1295	checkKeylogLines("server", serverBuf.String())
1296}
1297
1298func TestHandshakeClientALPNMatch(t *testing.T) {
1299	config := testConfig.Clone()
1300	config.NextProtos = []string{"proto2", "proto1"}
1301
1302	test := &clientTest{
1303		name: "ALPN",
1304		// Note that this needs OpenSSL 1.0.2 because that is the first
1305		// version that supports the -alpn flag.
1306		args:   []string{"-alpn", "proto1,proto2"},
1307		config: config,
1308		validate: func(state ConnectionState) error {
1309			// The server's preferences should override the client.
1310			if state.NegotiatedProtocol != "proto1" {
1311				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
1312			}
1313			return nil
1314		},
1315	}
1316	runClientTestTLS12(t, test)
1317	runClientTestTLS13(t, test)
1318}
1319
1320func TestHandshakeClientEnforceALPNMatch(t *testing.T) {
1321	clientConn, serverConn := localPipe(t)
1322	server := Server(serverConn, testConfig, nil)
1323
1324	sErrChan := make(chan error)
1325	go func() {
1326		sErrChan <- server.Handshake()
1327	}()
1328
1329	config := testConfig.Clone()
1330	config.NextProtos = []string{"proto2", "proto1"}
1331	extraConf := &ExtraConfig{EnforceNextProtoSelection: true}
1332
1333	client := Client(clientConn, config, extraConf)
1334	err := client.Handshake()
1335	if err == nil || err.Error() != "ALPN negotiation failed. Server didn't offer any protocols" {
1336		t.Fatalf("Expected APLN negotiation to fail, got %s", err)
1337	}
1338	sErr := <-sErrChan
1339	if sErr == nil || !strings.Contains(sErr.Error(), "no application protocol") {
1340		t.Fatalf("Expect 'no_application_protocol' error, got %s", sErr)
1341	}
1342}
1343
// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`.
//
// Decoded, it is an OpenSSL serverinfo block for extension type 18
// (signed_certificate_timestamp). It consists of a 2-byte extension type, a
// 2-byte extension length, and then the extension body. The body is a
// length-prefixed list of length-prefixed SCTs. TestHandshakClientSCTs feeds
// the whole block to the OpenSSL server and slices the individual SCTs out of
// the decoded bytes at fixed offsets.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
1346
1347func TestHandshakClientSCTs(t *testing.T) {
1348	config := testConfig.Clone()
1349
1350	scts, err := base64.StdEncoding.DecodeString(sctsBase64)
1351	if err != nil {
1352		t.Fatal(err)
1353	}
1354
1355	// Note that this needs OpenSSL 1.0.2 because that is the first
1356	// version that supports the -serverinfo flag.
1357	test := &clientTest{
1358		name:       "SCT",
1359		config:     config,
1360		extensions: [][]byte{scts},
1361		validate: func(state ConnectionState) error {
1362			expectedSCTs := [][]byte{
1363				scts[8:125],
1364				scts[127:245],
1365				scts[247:],
1366			}
1367			if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
1368				return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
1369			}
1370			for i, expected := range expectedSCTs {
1371				if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
1372					return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
1373				}
1374			}
1375			return nil
1376		},
1377	}
1378	runClientTestTLS12(t, test)
1379
1380	// TLS 1.3 moved SCTs to the Certificate extensions and -serverinfo only
1381	// supports ServerHello extensions.
1382}
1383
1384func TestRenegotiationRejected(t *testing.T) {
1385	config := testConfig.Clone()
1386	test := &clientTest{
1387		name:                        "RenegotiationRejected",
1388		args:                        []string{"-state"},
1389		config:                      config,
1390		numRenegotiations:           1,
1391		renegotiationExpectedToFail: 1,
1392		checkRenegotiationError: func(renegotiationNum int, err error) error {
1393			if err == nil {
1394				return errors.New("expected error from renegotiation but got nil")
1395			}
1396			if !strings.Contains(err.Error(), "no renegotiation") {
1397				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1398			}
1399			return nil
1400		},
1401	}
1402	runClientTestTLS12(t, test)
1403}
1404
1405func TestRenegotiateOnce(t *testing.T) {
1406	config := testConfig.Clone()
1407	config.Renegotiation = RenegotiateOnceAsClient
1408
1409	test := &clientTest{
1410		name:              "RenegotiateOnce",
1411		args:              []string{"-state"},
1412		config:            config,
1413		numRenegotiations: 1,
1414	}
1415
1416	runClientTestTLS12(t, test)
1417}
1418
1419func TestRenegotiateTwice(t *testing.T) {
1420	config := testConfig.Clone()
1421	config.Renegotiation = RenegotiateFreelyAsClient
1422
1423	test := &clientTest{
1424		name:              "RenegotiateTwice",
1425		args:              []string{"-state"},
1426		config:            config,
1427		numRenegotiations: 2,
1428	}
1429
1430	runClientTestTLS12(t, test)
1431}
1432
1433func TestRenegotiateTwiceRejected(t *testing.T) {
1434	config := testConfig.Clone()
1435	config.Renegotiation = RenegotiateOnceAsClient
1436
1437	test := &clientTest{
1438		name:                        "RenegotiateTwiceRejected",
1439		args:                        []string{"-state"},
1440		config:                      config,
1441		numRenegotiations:           2,
1442		renegotiationExpectedToFail: 2,
1443		checkRenegotiationError: func(renegotiationNum int, err error) error {
1444			if renegotiationNum == 1 {
1445				return err
1446			}
1447
1448			if err == nil {
1449				return errors.New("expected error from renegotiation but got nil")
1450			}
1451			if !strings.Contains(err.Error(), "no renegotiation") {
1452				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
1453			}
1454			return nil
1455		},
1456	}
1457
1458	runClientTestTLS12(t, test)
1459}
1460
1461func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
1462	test := &clientTest{
1463		name:   "ExportKeyingMaterial",
1464		config: testConfig.Clone(),
1465		validate: func(state ConnectionState) error {
1466			if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
1467				return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
1468			} else if len(km) != 42 {
1469				return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
1470			}
1471			return nil
1472		},
1473	}
1474	runClientTestTLS10(t, test)
1475	runClientTestTLS12(t, test)
1476	runClientTestTLS13(t, test)
1477}
1478
// hostnameInSNITests pairs a Config.ServerName value (in) with the server name
// expected in the client's ClientHello (out). Opaque strings and DNS names are
// sent as-is, except that a DNS name loses its trailing dot. IP literals
// (IPv4, IPv6 with or without zone, bracketed or not) are expected to produce
// an empty server name.
var hostnameInSNITests = []struct {
	in, out string
}{
	// Opaque strings are passed through unchanged.
	{in: "", out: ""},
	{in: "localhost", out: "localhost"},
	{in: "foo, bar, baz and qux", out: "foo, bar, baz and qux"},

	// DNS host names; a trailing root dot is removed.
	{in: "golang.org", out: "golang.org"},
	{in: "golang.org.", out: "golang.org"},

	// Literal IPv4 address.
	{in: "1.2.3.4", out: ""},

	// Literal IPv6 addresses.
	{in: "::1", out: ""},
	{in: "::1%lo0", out: ""}, // with zone identifier
	{in: "[::1]", out: ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
	{in: "[::1%lo0]", out: ""},
}
1500
1501func TestHostnameInSNI(t *testing.T) {
1502	for _, tt := range hostnameInSNITests {
1503		c, s := localPipe(t)
1504
1505		go func(host string) {
1506			Client(c, &Config{ServerName: host, InsecureSkipVerify: true}, nil).Handshake()
1507		}(tt.in)
1508
1509		var header [5]byte
1510		if _, err := io.ReadFull(s, header[:]); err != nil {
1511			t.Fatal(err)
1512		}
1513		recordLen := int(header[3])<<8 | int(header[4])
1514
1515		record := make([]byte, recordLen)
1516		if _, err := io.ReadFull(s, record[:]); err != nil {
1517			t.Fatal(err)
1518		}
1519
1520		c.Close()
1521		s.Close()
1522
1523		var m clientHelloMsg
1524		if !m.unmarshal(record) {
1525			t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
1526			continue
1527		}
1528		if tt.in != tt.out && m.serverName == tt.in {
1529			t.Errorf("prohibited %q found in ClientHello: %x", tt.in, record)
1530		}
1531		if m.serverName != tt.out {
1532			t.Errorf("expected %q not found in ClientHello: %x", tt.out, record)
1533		}
1534	}
1535}
1536
1537func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
1538	// This checks that the server can't select a cipher suite that the
1539	// client didn't offer. See #13174.
1540
1541	c, s := localPipe(t)
1542	errChan := make(chan error, 1)
1543
1544	go func() {
1545		client := Client(c, &Config{
1546			ServerName:   "foo",
1547			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
1548		}, nil)
1549		errChan <- client.Handshake()
1550	}()
1551
1552	var header [5]byte
1553	if _, err := io.ReadFull(s, header[:]); err != nil {
1554		t.Fatal(err)
1555	}
1556	recordLen := int(header[3])<<8 | int(header[4])
1557
1558	record := make([]byte, recordLen)
1559	if _, err := io.ReadFull(s, record); err != nil {
1560		t.Fatal(err)
1561	}
1562
1563	// Create a ServerHello that selects a different cipher suite than the
1564	// sole one that the client offered.
1565	serverHello := &serverHelloMsg{
1566		vers:        VersionTLS12,
1567		random:      make([]byte, 32),
1568		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
1569	}
1570	serverHelloBytes := serverHello.marshal()
1571
1572	s.Write([]byte{
1573		byte(recordTypeHandshake),
1574		byte(VersionTLS12 >> 8),
1575		byte(VersionTLS12 & 0xff),
1576		byte(len(serverHelloBytes) >> 8),
1577		byte(len(serverHelloBytes)),
1578	})
1579	s.Write(serverHelloBytes)
1580	s.Close()
1581
1582	if err := <-errChan; !strings.Contains(err.Error(), "unconfigured cipher") {
1583		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
1584	}
1585}
1586
1587func TestVerifyConnection(t *testing.T) {
1588	t.Run("TLSv12", func(t *testing.T) { testVerifyConnection(t, VersionTLS12) })
1589	t.Run("TLSv13", func(t *testing.T) { testVerifyConnection(t, VersionTLS13) })
1590}
1591
1592func testVerifyConnection(t *testing.T, version uint16) {
1593	checkFields := func(c ConnectionState, called *int, errorType string) error {
1594		if c.Version != version {
1595			return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version)
1596		}
1597		if c.HandshakeComplete {
1598			return fmt.Errorf("%s: got HandshakeComplete, want false", errorType)
1599		}
1600		if c.ServerName != "example.golang" {
1601			return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang")
1602		}
1603		if c.NegotiatedProtocol != "protocol1" {
1604			return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1")
1605		}
1606		if c.CipherSuite == 0 {
1607			return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType)
1608		}
1609		wantDidResume := false
1610		if *called == 2 { // if this is the second time, then it should be a resumption
1611			wantDidResume = true
1612		}
1613		if c.DidResume != wantDidResume {
1614			return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume)
1615		}
1616		return nil
1617	}
1618
1619	tests := []struct {
1620		name            string
1621		configureServer func(*Config, *int)
1622		configureClient func(*Config, *int)
1623	}{
1624		{
1625			name: "RequireAndVerifyClientCert",
1626			configureServer: func(config *Config, called *int) {
1627				config.ClientAuth = RequireAndVerifyClientCert
1628				config.VerifyConnection = func(c ConnectionState) error {
1629					*called++
1630					if l := len(c.PeerCertificates); l != 1 {
1631						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1632					}
1633					if len(c.VerifiedChains) == 0 {
1634						return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero")
1635					}
1636					return checkFields(c, called, "server")
1637				}
1638			},
1639			configureClient: func(config *Config, called *int) {
1640				config.VerifyConnection = func(c ConnectionState) error {
1641					*called++
1642					if l := len(c.PeerCertificates); l != 1 {
1643						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1644					}
1645					if len(c.VerifiedChains) == 0 {
1646						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1647					}
1648					if c.DidResume {
1649						return nil
1650						// The SCTs and OCSP Responce are dropped on resumption.
1651						// See http://golang.org/issue/39075.
1652					}
1653					if len(c.OCSPResponse) == 0 {
1654						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1655					}
1656					if len(c.SignedCertificateTimestamps) == 0 {
1657						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1658					}
1659					return checkFields(c, called, "client")
1660				}
1661			},
1662		},
1663		{
1664			name: "InsecureSkipVerify",
1665			configureServer: func(config *Config, called *int) {
1666				config.ClientAuth = RequireAnyClientCert
1667				config.InsecureSkipVerify = true
1668				config.VerifyConnection = func(c ConnectionState) error {
1669					*called++
1670					if l := len(c.PeerCertificates); l != 1 {
1671						return fmt.Errorf("server: got len(PeerCertificates) = %d, wanted 1", l)
1672					}
1673					if c.VerifiedChains != nil {
1674						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1675					}
1676					return checkFields(c, called, "server")
1677				}
1678			},
1679			configureClient: func(config *Config, called *int) {
1680				config.InsecureSkipVerify = true
1681				config.VerifyConnection = func(c ConnectionState) error {
1682					*called++
1683					if l := len(c.PeerCertificates); l != 1 {
1684						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1685					}
1686					if c.VerifiedChains != nil {
1687						return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains)
1688					}
1689					if c.DidResume {
1690						return nil
1691						// The SCTs and OCSP Responce are dropped on resumption.
1692						// See http://golang.org/issue/39075.
1693					}
1694					if len(c.OCSPResponse) == 0 {
1695						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1696					}
1697					if len(c.SignedCertificateTimestamps) == 0 {
1698						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1699					}
1700					return checkFields(c, called, "client")
1701				}
1702			},
1703		},
1704		{
1705			name: "NoClientCert",
1706			configureServer: func(config *Config, called *int) {
1707				config.ClientAuth = NoClientCert
1708				config.VerifyConnection = func(c ConnectionState) error {
1709					*called++
1710					return checkFields(c, called, "server")
1711				}
1712			},
1713			configureClient: func(config *Config, called *int) {
1714				config.VerifyConnection = func(c ConnectionState) error {
1715					*called++
1716					return checkFields(c, called, "client")
1717				}
1718			},
1719		},
1720		{
1721			name: "RequestClientCert",
1722			configureServer: func(config *Config, called *int) {
1723				config.ClientAuth = RequestClientCert
1724				config.VerifyConnection = func(c ConnectionState) error {
1725					*called++
1726					return checkFields(c, called, "server")
1727				}
1728			},
1729			configureClient: func(config *Config, called *int) {
1730				config.Certificates = nil // clear the client cert
1731				config.VerifyConnection = func(c ConnectionState) error {
1732					*called++
1733					if l := len(c.PeerCertificates); l != 1 {
1734						return fmt.Errorf("client: got len(PeerCertificates) = %d, wanted 1", l)
1735					}
1736					if len(c.VerifiedChains) == 0 {
1737						return fmt.Errorf("client: got len(VerifiedChains) = 0, wanted non-zero")
1738					}
1739					if c.DidResume {
1740						return nil
1741						// The SCTs and OCSP Responce are dropped on resumption.
1742						// See http://golang.org/issue/39075.
1743					}
1744					if len(c.OCSPResponse) == 0 {
1745						return fmt.Errorf("client: got len(OCSPResponse) = 0, wanted non-zero")
1746					}
1747					if len(c.SignedCertificateTimestamps) == 0 {
1748						return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero")
1749					}
1750					return checkFields(c, called, "client")
1751				}
1752			},
1753		},
1754	}
1755	for _, test := range tests {
1756		issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1757		if err != nil {
1758			panic(err)
1759		}
1760		rootCAs := x509.NewCertPool()
1761		rootCAs.AddCert(issuer)
1762
1763		var serverCalled, clientCalled int
1764
1765		serverConfig := &Config{
1766			MaxVersion:   version,
1767			Certificates: []Certificate{testConfig.Certificates[0]},
1768			ClientCAs:    rootCAs,
1769			NextProtos:   []string{"protocol1"},
1770		}
1771		serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
1772		serverConfig.Certificates[0].OCSPStaple = []byte("dummy ocsp")
1773		test.configureServer(serverConfig, &serverCalled)
1774
1775		clientConfig := &Config{
1776			MaxVersion:         version,
1777			ClientSessionCache: NewLRUClientSessionCache(32),
1778			RootCAs:            rootCAs,
1779			ServerName:         "example.golang",
1780			Certificates:       []Certificate{testConfig.Certificates[0]},
1781			NextProtos:         []string{"protocol1"},
1782		}
1783		test.configureClient(clientConfig, &clientCalled)
1784
1785		testHandshakeState := func(name string, didResume bool) {
1786			_, hs, err := testHandshake(t, clientConfig, serverConfig)
1787			if err != nil {
1788				t.Fatalf("%s: handshake failed: %s", name, err)
1789			}
1790			if hs.DidResume != didResume {
1791				t.Errorf("%s: resumed: %v, expected: %v", name, hs.DidResume, didResume)
1792			}
1793			wantCalled := 1
1794			if didResume {
1795				wantCalled = 2 // resumption would mean this is the second time it was called in this test
1796			}
1797			if clientCalled != wantCalled {
1798				t.Errorf("%s: expected client VerifyConnection called %d times, did %d times", name, wantCalled, clientCalled)
1799			}
1800			if serverCalled != wantCalled {
1801				t.Errorf("%s: expected server VerifyConnection called %d times, did %d times", name, wantCalled, serverCalled)
1802			}
1803		}
1804		testHandshakeState(fmt.Sprintf("%s-FullHandshake", test.name), false)
1805		testHandshakeState(fmt.Sprintf("%s-Resumption", test.name), true)
1806	}
1807}
1808
1809func TestVerifyPeerCertificate(t *testing.T) {
1810	t.Run("TLSv12", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS12) })
1811	t.Run("TLSv13", func(t *testing.T) { testVerifyPeerCertificate(t, VersionTLS13) })
1812}
1813
1814func testVerifyPeerCertificate(t *testing.T, version uint16) {
1815	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1816	if err != nil {
1817		panic(err)
1818	}
1819
1820	rootCAs := x509.NewCertPool()
1821	rootCAs.AddCert(issuer)
1822
1823	now := func() time.Time { return time.Unix(1476984729, 0) }
1824
1825	sentinelErr := errors.New("TestVerifyPeerCertificate")
1826
1827	verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1828		if l := len(rawCerts); l != 1 {
1829			return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1830		}
1831		if len(validatedChains) == 0 {
1832			return errors.New("got len(validatedChains) = 0, wanted non-zero")
1833		}
1834		*called = true
1835		return nil
1836	}
1837	verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error {
1838		if l := len(c.PeerCertificates); l != 1 {
1839			return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l)
1840		}
1841		if len(c.VerifiedChains) == 0 {
1842			return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero")
1843		}
1844		if isClient && len(c.OCSPResponse) == 0 {
1845			return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero")
1846		}
1847		*called = true
1848		return nil
1849	}
1850
1851	tests := []struct {
1852		configureServer func(*Config, *bool)
1853		configureClient func(*Config, *bool)
1854		validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
1855	}{
1856		{
1857			configureServer: func(config *Config, called *bool) {
1858				config.InsecureSkipVerify = false
1859				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1860					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1861				}
1862			},
1863			configureClient: func(config *Config, called *bool) {
1864				config.InsecureSkipVerify = false
1865				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1866					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
1867				}
1868			},
1869			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1870				if clientErr != nil {
1871					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1872				}
1873				if serverErr != nil {
1874					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1875				}
1876				if !clientCalled {
1877					t.Errorf("test[%d]: client did not call callback", testNo)
1878				}
1879				if !serverCalled {
1880					t.Errorf("test[%d]: server did not call callback", testNo)
1881				}
1882			},
1883		},
1884		{
1885			configureServer: func(config *Config, called *bool) {
1886				config.InsecureSkipVerify = false
1887				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1888					return sentinelErr
1889				}
1890			},
1891			configureClient: func(config *Config, called *bool) {
1892				config.VerifyPeerCertificate = nil
1893			},
1894			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1895				if serverErr != sentinelErr {
1896					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1897				}
1898			},
1899		},
1900		{
1901			configureServer: func(config *Config, called *bool) {
1902				config.InsecureSkipVerify = false
1903			},
1904			configureClient: func(config *Config, called *bool) {
1905				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1906					return sentinelErr
1907				}
1908			},
1909			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1910				if clientErr != sentinelErr {
1911					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
1912				}
1913			},
1914		},
1915		{
1916			configureServer: func(config *Config, called *bool) {
1917				config.InsecureSkipVerify = false
1918			},
1919			configureClient: func(config *Config, called *bool) {
1920				config.InsecureSkipVerify = true
1921				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
1922					if l := len(rawCerts); l != 1 {
1923						return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
1924					}
1925					// With InsecureSkipVerify set, this
1926					// callback should still be called but
1927					// validatedChains must be empty.
1928					if l := len(validatedChains); l != 0 {
1929						return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
1930					}
1931					*called = true
1932					return nil
1933				}
1934			},
1935			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1936				if clientErr != nil {
1937					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1938				}
1939				if serverErr != nil {
1940					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1941				}
1942				if !clientCalled {
1943					t.Errorf("test[%d]: client did not call callback", testNo)
1944				}
1945			},
1946		},
1947		{
1948			configureServer: func(config *Config, called *bool) {
1949				config.InsecureSkipVerify = false
1950				config.VerifyConnection = func(c ConnectionState) error {
1951					return verifyConnectionCallback(called, false, c)
1952				}
1953			},
1954			configureClient: func(config *Config, called *bool) {
1955				config.InsecureSkipVerify = false
1956				config.VerifyConnection = func(c ConnectionState) error {
1957					return verifyConnectionCallback(called, true, c)
1958				}
1959			},
1960			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1961				if clientErr != nil {
1962					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
1963				}
1964				if serverErr != nil {
1965					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
1966				}
1967				if !clientCalled {
1968					t.Errorf("test[%d]: client did not call callback", testNo)
1969				}
1970				if !serverCalled {
1971					t.Errorf("test[%d]: server did not call callback", testNo)
1972				}
1973			},
1974		},
1975		{
1976			configureServer: func(config *Config, called *bool) {
1977				config.InsecureSkipVerify = false
1978				config.VerifyConnection = func(c ConnectionState) error {
1979					return sentinelErr
1980				}
1981			},
1982			configureClient: func(config *Config, called *bool) {
1983				config.InsecureSkipVerify = false
1984				config.VerifyConnection = nil
1985			},
1986			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
1987				if serverErr != sentinelErr {
1988					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
1989				}
1990			},
1991		},
1992		{
1993			configureServer: func(config *Config, called *bool) {
1994				config.InsecureSkipVerify = false
1995				config.VerifyConnection = nil
1996			},
1997			configureClient: func(config *Config, called *bool) {
1998				config.InsecureSkipVerify = false
1999				config.VerifyConnection = func(c ConnectionState) error {
2000					return sentinelErr
2001				}
2002			},
2003			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
2004				if clientErr != sentinelErr {
2005					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
2006				}
2007			},
2008		},
2009		{
2010			configureServer: func(config *Config, called *bool) {
2011				config.InsecureSkipVerify = false
2012				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
2013					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
2014				}
2015				config.VerifyConnection = func(c ConnectionState) error {
2016					return sentinelErr
2017				}
2018			},
2019			configureClient: func(config *Config, called *bool) {
2020				config.InsecureSkipVerify = false
2021				config.VerifyPeerCertificate = nil
2022				config.VerifyConnection = nil
2023			},
2024			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
2025				if serverErr != sentinelErr {
2026					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
2027				}
2028				if !serverCalled {
2029					t.Errorf("test[%d]: server did not call callback", testNo)
2030				}
2031			},
2032		},
2033		{
2034			configureServer: func(config *Config, called *bool) {
2035				config.InsecureSkipVerify = false
2036				config.VerifyPeerCertificate = nil
2037				config.VerifyConnection = nil
2038			},
2039			configureClient: func(config *Config, called *bool) {
2040				config.InsecureSkipVerify = false
2041				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
2042					return verifyPeerCertificateCallback(called, rawCerts, validatedChains)
2043				}
2044				config.VerifyConnection = func(c ConnectionState) error {
2045					return sentinelErr
2046				}
2047			},
2048			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
2049				if clientErr != sentinelErr {
2050					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
2051				}
2052				if !clientCalled {
2053					t.Errorf("test[%d]: client did not call callback", testNo)
2054				}
2055			},
2056		},
2057	}
2058
2059	for i, test := range tests {
2060		c, s := localPipe(t)
2061		done := make(chan error)
2062
2063		var clientCalled, serverCalled bool
2064
2065		go func() {
2066			config := testConfig.Clone()
2067			config.ServerName = "example.golang"
2068			config.ClientAuth = RequireAndVerifyClientCert
2069			config.ClientCAs = rootCAs
2070			config.Time = now
2071			config.MaxVersion = version
2072			config.Certificates = make([]Certificate, 1)
2073			config.Certificates[0].Certificate = [][]byte{testRSACertificate}
2074			config.Certificates[0].PrivateKey = testRSAPrivateKey
2075			config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")}
2076			config.Certificates[0].OCSPStaple = []byte("dummy ocsp")
2077			test.configureServer(config, &serverCalled)
2078
2079			err = Server(s, config, nil).Handshake()
2080			s.Close()
2081			done <- err
2082		}()
2083
2084		config := testConfig.Clone()
2085		config.ServerName = "example.golang"
2086		config.RootCAs = rootCAs
2087		config.Time = now
2088		config.MaxVersion = version
2089		test.configureClient(config, &clientCalled)
2090		clientErr := Client(c, config, nil).Handshake()
2091		c.Close()
2092		serverErr := <-done
2093
2094		test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
2095	}
2096}
2097
// brokenConn wraps a net.Conn and forwards only the first breakAfter Writes
// to it. Every later Write fails with brokenConnErr without touching the
// wrapped connection.
type brokenConn struct {
	net.Conn

	// breakAfter is how many Writes are forwarded to the wrapped Conn
	// before the connection starts failing.
	breakAfter int

	// numWrites counts the Writes that have been forwarded so far.
	numWrites int
}

// brokenConnErr is what brokenConn.Write returns once its write budget is used
// up.
var brokenConnErr = errors.New("too many writes to brokenConn")

// Write forwards data to the wrapped Conn while fewer than breakAfter Writes
// have been forwarded. Afterwards it returns (0, brokenConnErr).
func (b *brokenConn) Write(data []byte) (int, error) {
	if b.numWrites < b.breakAfter {
		b.numWrites++
		return b.Conn.Write(data)
	}
	return 0, brokenConnErr
}
2122
2123func TestFailedWrite(t *testing.T) {
2124	// Test that a write error during the handshake is returned.
2125	for _, breakAfter := range []int{0, 1} {
2126		c, s := localPipe(t)
2127		done := make(chan bool)
2128
2129		go func() {
2130			Server(s, testConfig, nil).Handshake()
2131			s.Close()
2132			done <- true
2133		}()
2134
2135		brokenC := &brokenConn{Conn: c, breakAfter: breakAfter}
2136		err := Client(brokenC, testConfig, nil).Handshake()
2137		if err != brokenConnErr {
2138			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
2139		}
2140		brokenC.Close()
2141
2142		<-done
2143	}
2144}
2145
// writeCountingConn wraps a net.Conn and records how many times Write is
// called, whether or not the underlying Write succeeds.
type writeCountingConn struct {
	net.Conn

	// numWrites is the number of Write calls made so far.
	numWrites int
}

// Write increments the call counter and then delegates to the wrapped Conn,
// returning its results unchanged.
func (wcc *writeCountingConn) Write(data []byte) (n int, err error) {
	wcc.numWrites++
	n, err = wcc.Conn.Write(data)
	return n, err
}
2158
2159func TestBuffering(t *testing.T) {
2160	t.Run("TLSv12", func(t *testing.T) { testBuffering(t, VersionTLS12) })
2161	t.Run("TLSv13", func(t *testing.T) { testBuffering(t, VersionTLS13) })
2162}
2163
2164func testBuffering(t *testing.T, version uint16) {
2165	c, s := localPipe(t)
2166	done := make(chan bool)
2167
2168	clientWCC := &writeCountingConn{Conn: c}
2169	serverWCC := &writeCountingConn{Conn: s}
2170
2171	go func() {
2172		config := testConfig.Clone()
2173		config.MaxVersion = version
2174		Server(serverWCC, config, nil).Handshake()
2175		serverWCC.Close()
2176		done <- true
2177	}()
2178
2179	err := Client(clientWCC, testConfig, nil).Handshake()
2180	if err != nil {
2181		t.Fatal(err)
2182	}
2183	clientWCC.Close()
2184	<-done
2185
2186	var expectedClient, expectedServer int
2187	if version == VersionTLS13 {
2188		expectedClient = 2
2189		expectedServer = 1
2190	} else {
2191		expectedClient = 2
2192		expectedServer = 2
2193	}
2194
2195	if n := clientWCC.numWrites; n != expectedClient {
2196		t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
2197	}
2198
2199	if n := serverWCC.numWrites; n != expectedServer {
2200		t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
2201	}
2202}
2203
2204func TestAlertFlushing(t *testing.T) {
2205	c, s := localPipe(t)
2206	done := make(chan bool)
2207
2208	clientWCC := &writeCountingConn{Conn: c}
2209	serverWCC := &writeCountingConn{Conn: s}
2210
2211	serverConfig := testConfig.Clone()
2212
2213	// Cause a signature-time error
2214	brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
2215	brokenKey.D = big.NewInt(42)
2216	serverConfig.Certificates = []Certificate{{
2217		Certificate: [][]byte{testRSACertificate},
2218		PrivateKey:  &brokenKey,
2219	}}
2220
2221	go func() {
2222		Server(serverWCC, serverConfig, nil).Handshake()
2223		serverWCC.Close()
2224		done <- true
2225	}()
2226
2227	err := Client(clientWCC, testConfig, nil).Handshake()
2228	if err == nil {
2229		t.Fatal("client unexpectedly returned no error")
2230	}
2231
2232	const expectedError = "remote error: tls: internal error"
2233	if e := err.Error(); !strings.Contains(e, expectedError) {
2234		t.Fatalf("expected to find %q in error but error was %q", expectedError, e)
2235	}
2236	clientWCC.Close()
2237	<-done
2238
2239	if n := serverWCC.numWrites; n != 1 {
2240		t.Errorf("expected server handshake to complete with one write, but saw %d", n)
2241	}
2242}
2243
2244func TestHandshakeRace(t *testing.T) {
2245	if testing.Short() {
2246		t.Skip("skipping in -short mode")
2247	}
2248	t.Parallel()
2249	// This test races a Read and Write to try and complete a handshake in
2250	// order to provide some evidence that there are no races or deadlocks
2251	// in the handshake locking.
2252	for i := 0; i < 32; i++ {
2253		c, s := localPipe(t)
2254
2255		go func() {
2256			server := Server(s, testConfig, nil)
2257			if err := server.Handshake(); err != nil {
2258				panic(err)
2259			}
2260
2261			var request [1]byte
2262			if n, err := server.Read(request[:]); err != nil || n != 1 {
2263				panic(err)
2264			}
2265
2266			server.Write(request[:])
2267			server.Close()
2268		}()
2269
2270		startWrite := make(chan struct{})
2271		startRead := make(chan struct{})
2272		readDone := make(chan struct{}, 1)
2273
2274		client := Client(c, testConfig, nil)
2275		go func() {
2276			<-startWrite
2277			var request [1]byte
2278			client.Write(request[:])
2279		}()
2280
2281		go func() {
2282			<-startRead
2283			var reply [1]byte
2284			if _, err := io.ReadFull(client, reply[:]); err != nil {
2285				panic(err)
2286			}
2287			c.Close()
2288			readDone <- struct{}{}
2289		}()
2290
2291		if i&1 == 1 {
2292			startWrite <- struct{}{}
2293			startRead <- struct{}{}
2294		} else {
2295			startRead <- struct{}{}
2296			startWrite <- struct{}{}
2297		}
2298		<-readDone
2299	}
2300}
2301
2302var getClientCertificateTests = []struct {
2303	setup               func(*Config, *Config)
2304	expectedClientError string
2305	verify              func(*testing.T, int, *ConnectionState)
2306}{
2307	{
2308		func(clientConfig, serverConfig *Config) {
2309			// Returning a Certificate with no certificate data
2310			// should result in an empty message being sent to the
2311			// server.
2312			serverConfig.ClientCAs = nil
2313			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2314				if len(cri.SignatureSchemes) == 0 {
2315					panic("empty SignatureSchemes")
2316				}
2317				if len(cri.AcceptableCAs) != 0 {
2318					panic("AcceptableCAs should have been empty")
2319				}
2320				return new(Certificate), nil
2321			}
2322		},
2323		"",
2324		func(t *testing.T, testNum int, cs *ConnectionState) {
2325			if l := len(cs.PeerCertificates); l != 0 {
2326				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2327			}
2328		},
2329	},
2330	{
2331		func(clientConfig, serverConfig *Config) {
2332			// With TLS 1.1, the SignatureSchemes should be
2333			// synthesised from the supported certificate types.
2334			clientConfig.MaxVersion = VersionTLS11
2335			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2336				if len(cri.SignatureSchemes) == 0 {
2337					panic("empty SignatureSchemes")
2338				}
2339				return new(Certificate), nil
2340			}
2341		},
2342		"",
2343		func(t *testing.T, testNum int, cs *ConnectionState) {
2344			if l := len(cs.PeerCertificates); l != 0 {
2345				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
2346			}
2347		},
2348	},
2349	{
2350		func(clientConfig, serverConfig *Config) {
2351			// Returning an error should abort the handshake with
2352			// that error.
2353			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2354				return nil, errors.New("GetClientCertificate")
2355			}
2356		},
2357		"GetClientCertificate",
2358		func(t *testing.T, testNum int, cs *ConnectionState) {
2359		},
2360	},
2361	{
2362		func(clientConfig, serverConfig *Config) {
2363			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
2364				if len(cri.AcceptableCAs) == 0 {
2365					panic("empty AcceptableCAs")
2366				}
2367				cert := &Certificate{
2368					Certificate: [][]byte{testRSACertificate},
2369					PrivateKey:  testRSAPrivateKey,
2370				}
2371				return cert, nil
2372			}
2373		},
2374		"",
2375		func(t *testing.T, testNum int, cs *ConnectionState) {
2376			if len(cs.VerifiedChains) == 0 {
2377				t.Errorf("#%d: expected some verified chains, but found none", testNum)
2378			}
2379		},
2380	},
2381}
2382
2383func TestGetClientCertificate(t *testing.T) {
2384	t.Run("TLSv12", func(t *testing.T) { testGetClientCertificate(t, VersionTLS12) })
2385	t.Run("TLSv13", func(t *testing.T) { testGetClientCertificate(t, VersionTLS13) })
2386}
2387
2388func testGetClientCertificate(t *testing.T, version uint16) {
2389	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2390	if err != nil {
2391		panic(err)
2392	}
2393
2394	for i, test := range getClientCertificateTests {
2395		serverConfig := testConfig.Clone()
2396		serverConfig.ClientAuth = VerifyClientCertIfGiven
2397		serverConfig.RootCAs = x509.NewCertPool()
2398		serverConfig.RootCAs.AddCert(issuer)
2399		serverConfig.ClientCAs = serverConfig.RootCAs
2400		serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
2401		serverConfig.MaxVersion = version
2402
2403		clientConfig := testConfig.Clone()
2404		clientConfig.MaxVersion = version
2405
2406		test.setup(clientConfig, serverConfig)
2407
2408		type serverResult struct {
2409			cs  ConnectionState
2410			err error
2411		}
2412
2413		c, s := localPipe(t)
2414		done := make(chan serverResult)
2415
2416		go func() {
2417			defer s.Close()
2418			server := Server(s, serverConfig, nil)
2419			err := server.Handshake()
2420
2421			var cs ConnectionState
2422			if err == nil {
2423				cs = server.ConnectionState()
2424			}
2425			done <- serverResult{cs, err}
2426		}()
2427
2428		clientErr := Client(c, clientConfig, nil).Handshake()
2429		c.Close()
2430
2431		result := <-done
2432
2433		if clientErr != nil {
2434			if len(test.expectedClientError) == 0 {
2435				t.Errorf("#%d: client error: %v", i, clientErr)
2436			} else if got := clientErr.Error(); got != test.expectedClientError {
2437				t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
2438			} else {
2439				test.verify(t, i, &result.cs)
2440			}
2441		} else if len(test.expectedClientError) > 0 {
2442			t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
2443		} else if err := result.err; err != nil {
2444			t.Errorf("#%d: server error: %v", i, err)
2445		} else {
2446			test.verify(t, i, &result.cs)
2447		}
2448	}
2449}
2450
// TestRSAPSSKeyError ensures that RSASSA-PSS certificates are never treated
// as ordinary RSA certificates.
//
// crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If
// crypto/x509 ever learns to parse public keys with the RSASSA-PSS OID, those
// keys must not surface as *rsa.PublicKey, or they would be misused with the
// rsa_pss_rsae_* SignatureSchemes. The test passes if the certificate below
// either fails to parse or parses to something other than *rsa.PublicKey.
func TestRSAPSSKeyError(t *testing.T) {
	const pssCertPEM = `
-----BEGIN CERTIFICATE-----
MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
RwBA9Xk1KBNF
-----END CERTIFICATE-----`

	block, _ := pem.Decode([]byte(pssCertPEM))
	if block == nil {
		t.Fatal("Failed to decode certificate")
	}

	// Rejecting the certificate outright is an acceptable outcome; only a
	// successful parse needs its key type checked.
	if cert, err := x509.ParseCertificate(block.Bytes); err == nil {
		if _, isRSA := cert.PublicKey.(*rsa.PublicKey); isRSA {
			t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
		}
	}
}
2489
2490func TestCloseClientConnectionOnIdleServer(t *testing.T) {
2491	clientConn, serverConn := localPipe(t)
2492	client := Client(clientConn, testConfig.Clone(), nil)
2493	go func() {
2494		var b [1]byte
2495		serverConn.Read(b[:])
2496		client.Close()
2497	}()
2498	client.SetWriteDeadline(time.Now().Add(time.Minute))
2499	err := client.Handshake()
2500	if err != nil {
2501		if err, ok := err.(net.Error); ok && err.Timeout() {
2502			t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
2503		}
2504	} else {
2505		t.Errorf("Error expected, but no error returned")
2506	}
2507}
2508
2509func testDowngradeCanary(t *testing.T, clientVersion, serverVersion uint16) error {
2510	defer func() { testingOnlyForceDowngradeCanary = false }()
2511	testingOnlyForceDowngradeCanary = true
2512
2513	clientConfig := testConfig.Clone()
2514	clientConfig.MaxVersion = clientVersion
2515	serverConfig := testConfig.Clone()
2516	serverConfig.MaxVersion = serverVersion
2517	_, _, err := testHandshake(t, clientConfig, serverConfig)
2518	return err
2519}
2520
2521func TestDowngradeCanary(t *testing.T) {
2522	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS12); err == nil {
2523		t.Errorf("downgrade from TLS 1.3 to TLS 1.2 was not detected")
2524	}
2525	if testing.Short() {
2526		t.Skip("skipping the rest of the checks in short mode")
2527	}
2528	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS11); err == nil {
2529		t.Errorf("downgrade from TLS 1.3 to TLS 1.1 was not detected")
2530	}
2531	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS10); err == nil {
2532		t.Errorf("downgrade from TLS 1.3 to TLS 1.0 was not detected")
2533	}
2534	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS11); err == nil {
2535		t.Errorf("downgrade from TLS 1.2 to TLS 1.1 was not detected")
2536	}
2537	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS10); err == nil {
2538		t.Errorf("downgrade from TLS 1.2 to TLS 1.0 was not detected")
2539	}
2540	if err := testDowngradeCanary(t, VersionTLS13, VersionTLS13); err != nil {
2541		t.Errorf("server unexpectedly sent downgrade canary for TLS 1.3")
2542	}
2543	if err := testDowngradeCanary(t, VersionTLS12, VersionTLS12); err != nil {
2544		t.Errorf("client didn't ignore expected TLS 1.2 canary")
2545	}
2546	if err := testDowngradeCanary(t, VersionTLS11, VersionTLS11); err != nil {
2547		t.Errorf("client unexpectedly reacted to a canary in TLS 1.1")
2548	}
2549	if err := testDowngradeCanary(t, VersionTLS10, VersionTLS10); err != nil {
2550		t.Errorf("client unexpectedly reacted to a canary in TLS 1.0")
2551	}
2552}
2553
2554func TestResumptionKeepsOCSPAndSCT(t *testing.T) {
2555	t.Run("TLSv12", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS12) })
2556	t.Run("TLSv13", func(t *testing.T) { testResumptionKeepsOCSPAndSCT(t, VersionTLS13) })
2557}
2558
2559func testResumptionKeepsOCSPAndSCT(t *testing.T, ver uint16) {
2560	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
2561	if err != nil {
2562		t.Fatalf("failed to parse test issuer")
2563	}
2564	roots := x509.NewCertPool()
2565	roots.AddCert(issuer)
2566	clientConfig := &Config{
2567		MaxVersion:         ver,
2568		ClientSessionCache: NewLRUClientSessionCache(32),
2569		ServerName:         "example.golang",
2570		RootCAs:            roots,
2571	}
2572	serverConfig := testConfig.Clone()
2573	serverConfig.MaxVersion = ver
2574	serverConfig.Certificates[0].OCSPStaple = []byte{1, 2, 3}
2575	serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{4, 5, 6}}
2576
2577	_, ccs, err := testHandshake(t, clientConfig, serverConfig)
2578	if err != nil {
2579		t.Fatalf("handshake failed: %s", err)
2580	}
2581	// after a new session we expect to see OCSPResponse and
2582	// SignedCertificateTimestamps populated as usual
2583	if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2584		t.Errorf("client ConnectionState contained unexpected OCSPResponse: wanted %v, got %v",
2585			serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2586	}
2587	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2588		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps: wanted %v, got %v",
2589			serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2590	}
2591
2592	// if the server doesn't send any SCTs, repopulate the old SCTs
2593	oldSCTs := serverConfig.Certificates[0].SignedCertificateTimestamps
2594	serverConfig.Certificates[0].SignedCertificateTimestamps = nil
2595	_, ccs, err = testHandshake(t, clientConfig, serverConfig)
2596	if err != nil {
2597		t.Fatalf("handshake failed: %s", err)
2598	}
2599	if !ccs.DidResume {
2600		t.Fatalf("expected session to be resumed")
2601	}
2602	// after a resumed session we also expect to see OCSPResponse
2603	// and SignedCertificateTimestamps populated
2604	if !bytes.Equal(ccs.OCSPResponse, serverConfig.Certificates[0].OCSPStaple) {
2605		t.Errorf("client ConnectionState contained unexpected OCSPResponse after resumption: wanted %v, got %v",
2606			serverConfig.Certificates[0].OCSPStaple, ccs.OCSPResponse)
2607	}
2608	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, oldSCTs) {
2609		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2610			oldSCTs, ccs.SignedCertificateTimestamps)
2611	}
2612
2613	//  Only test overriding the SCTs for TLS 1.2, since in 1.3
2614	// the server won't send the message containing them
2615	if ver == VersionTLS13 {
2616		return
2617	}
2618
2619	// if the server changes the SCTs it sends, they should override the saved SCTs
2620	serverConfig.Certificates[0].SignedCertificateTimestamps = [][]byte{{7, 8, 9}}
2621	_, ccs, err = testHandshake(t, clientConfig, serverConfig)
2622	if err != nil {
2623		t.Fatalf("handshake failed: %s", err)
2624	}
2625	if !ccs.DidResume {
2626		t.Fatalf("expected session to be resumed")
2627	}
2628	if !reflect.DeepEqual(ccs.SignedCertificateTimestamps, serverConfig.Certificates[0].SignedCertificateTimestamps) {
2629		t.Errorf("client ConnectionState contained unexpected SignedCertificateTimestamps after resumption: wanted %v, got %v",
2630			serverConfig.Certificates[0].SignedCertificateTimestamps, ccs.SignedCertificateTimestamps)
2631	}
2632}
2633