1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"crypto/ecdsa"
10	"crypto/rsa"
11	"crypto/x509"
12	"encoding/base64"
13	"encoding/binary"
14	"encoding/pem"
15	"errors"
16	"fmt"
17	"io"
18	"math/big"
19	"net"
20	"os"
21	"os/exec"
22	"path/filepath"
23	"strconv"
24	"strings"
25	"testing"
26	"time"
27)
28
29func init() {
30	// TLS 1.3 cipher suites preferences are not configurable and change based
31	// on the architecture. Force them to the version with AES acceleration for
32	// test consistency.
33	once.Do(initDefaultCipherSuites)
34	varDefaultCipherSuitesTLS13 = []uint16{
35		TLS_AES_128_GCM_SHA256,
36		TLS_CHACHA20_POLY1305_SHA256,
37		TLS_AES_256_GCM_SHA384,
38	}
39}
40
41// Note: see comment in handshake_test.go for details of how the reference
42// tests work.
43
// opensslInputEvent enumerates possible inputs that can be sent, via its
// stdin, to the `openssl s_server` reference process used by these tests.
type opensslInputEvent int

const (
	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
	// connection.
	opensslRenegotiate opensslInputEvent = iota

	// opensslSendSentinel causes OpenSSL to send the contents of
	// opensslSentinel on the connection.
	opensslSendSentinel

	// opensslKeyUpdate causes OpenSSL to send a key update message to the
	// client and request one back.
	opensslKeyUpdate
)

// opensslSentinel is the text written to OpenSSL's stdin for
// opensslSendSentinel, which OpenSSL then sends on to the peer.
const opensslSentinel = "SENTINEL\n"

// opensslInput is an io.Reader (used as the child process's stdin) that
// converts each opensslInputEvent received on the channel into the matching
// console command. Closing the channel makes Read return io.EOF.
type opensslInput chan opensslInputEvent

// Read blocks until one event is available and copies the corresponding
// command into buf (truncated to len(buf) if buf is too small). It returns
// 0, io.EOF once the channel is closed, and panics on an unknown event.
func (i opensslInput) Read(buf []byte) (n int, err error) {
	event, ok := <-i
	if !ok {
		return 0, io.EOF
	}

	// copy accepts a string source directly, avoiding a []byte conversion.
	switch event {
	case opensslRenegotiate:
		return copy(buf, "R\n"), nil
	case opensslKeyUpdate:
		return copy(buf, "K\n"), nil
	case opensslSendSentinel:
		return copy(buf, opensslSentinel), nil
	default:
		panic("unknown event")
	}
}
82
// opensslOutputSink is an io.Writer that receives the stdout and stderr from an
// `openssl` process. It keeps a copy of everything written and, for every
// complete line, sends a value on handshakeComplete or readKeyUpdate when the
// line equals opensslEndOfHandshake or opensslReadKeyUpdate respectively.
// Both channels are unbuffered, so Write blocks until the notification is
// received.
type opensslOutputSink struct {
	handshakeComplete chan struct{}
	readKeyUpdate     chan struct{}
	all               []byte // everything ever written, for diagnostics
	line              []byte // bytes after the last '\n' seen so far
}

// newOpensslOutputSink returns a sink with fresh, unbuffered notification
// channels and empty buffers.
func newOpensslOutputSink() *opensslOutputSink {
	return &opensslOutputSink{
		handshakeComplete: make(chan struct{}),
		readKeyUpdate:     make(chan struct{}),
	}
}

// opensslEndOfHandshake is a message that the “openssl s_server” tool will
// print when a handshake completes if run with “-state”.
const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

// opensslReadKeyUpdate is a message that the “openssl s_server” tool will
// print when a KeyUpdate message is received if run with “-state”.
const opensslReadKeyUpdate = "SSL_accept:TLSv1.3 read client key update"

// Write implements io.Writer. It records data, scans each newly completed
// line for the marker messages above and signals the matching channel. It
// never fails and always reports len(data) bytes written.
func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
	o.line = append(o.line, data...)
	o.all = append(o.all, data...)

	for {
		i := bytes.IndexByte(o.line, '\n')
		if i < 0 {
			break
		}

		// Switching on string(b) lets the compiler compare without
		// allocating a copy of the line.
		switch string(o.line[:i]) {
		case opensslEndOfHandshake:
			o.handshakeComplete <- struct{}{}
		case opensslReadKeyUpdate:
			o.readKeyUpdate <- struct{}{}
		}
		o.line = o.line[i+1:]
	}

	return len(data), nil
}

// String returns everything written to the sink so far.
func (o *opensslOutputSink) String() string {
	return string(o.all)
}
130
// clientTest represents a test of the TLS client handshake against a reference
// implementation. When recording (run with write=true) the reference is an
// `openssl s_server` process; otherwise previously recorded flows are replayed.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored (testdata/Client-<name>).
	name string
	// args, if not empty, contains a series of arguments for the
	// command to run for the reference server. They are appended after
	// serverCommand.
	args []string
	// config, if not nil, contains a custom Config to use for this test.
	// Otherwise testConfig is used.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server. Otherwise testRSACertificate is used.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// Otherwise testRSAPrivateKey is used; any other type panics.
	key interface{}
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection. It returns a non-nil
	// error if the ConnectionState is unacceptable.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated. A non-zero value requires "-state" in args.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the (1-based) number of the
	// renegotiation attempt that is expected to fail.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with the renegotiation
	// number and any error returned by the client's Read during that
	// renegotiation. Returning nil for a non-nil error causes the error to
	// be ignored.
	checkRenegotiationError func(renegotiationNum int, err error) error
	// sendKeyUpdate will cause the server to send a KeyUpdate message.
	// It requires "-state" in args.
	sendKeyUpdate bool
}
169
// serverCommand is the base command line for the OpenSSL reference server.
// The ticket-related flags turn off session tickets (presumably to keep the
// recorded flows free of server-issued tickets).
var serverCommand = []string{
	"openssl", "s_server",
	"-no_ticket",
	"-num_tickets", "0",
}
171
// connFromCommand starts the reference server process, connects to it and
// returns a recordingConn for the connection. The stdin return value is an
// opensslInput for the stdin of the child process. It must be closed before
// Waiting for child. stdout receives both the stdout and stderr of the child.
//
// The server certificate, private key and (optional) serverinfo file are
// written to temporary files that are removed when this function returns; by
// then the server is accepting connections and has presumably loaded them.
func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
	cert := testRSACertificate
	if len(test.cert) > 0 {
		cert = test.cert
	}
	certPath := tempFile(string(cert))
	defer os.Remove(certPath)

	var key interface{} = testRSAPrivateKey
	if test.key != nil {
		key = test.key
	}
	// Encode the key as a traditional "RSA PRIVATE KEY" (PKCS#1) or
	// "EC PRIVATE KEY" PEM block for OpenSSL's -key flag.
	var pemType string
	var derBytes []byte
	switch key := key.(type) {
	case *rsa.PrivateKey:
		pemType = "RSA"
		derBytes = x509.MarshalPKCS1PrivateKey(key)
	case *ecdsa.PrivateKey:
		pemType = "EC"
		var err error
		derBytes, err = x509.MarshalECPrivateKey(key)
		if err != nil {
			panic(err)
		}
	default:
		panic("unknown key type")
	}

	var pemOut bytes.Buffer
	pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})

	keyPath := tempFile(pemOut.String())
	defer os.Remove(keyPath)

	// Command line: serverCommand, then the test's own args, then the
	// certificate/key files and the listening port.
	var command []string
	command = append(command, serverCommand...)
	command = append(command, test.args...)
	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
	// serverPort contains the port that OpenSSL will listen on. OpenSSL
	// can't take "0" as an argument here so we have to pick a number and
	// hope that it's not in use on the machine. Since this only occurs
	// when -update is given and thus when there's a human watching the
	// test, this isn't too bad.
	const serverPort = 24323
	command = append(command, "-accept", strconv.Itoa(serverPort))

	// Custom ServerHello extensions are passed through a -serverinfo file
	// containing one PEM block per extension, labelled with its type.
	if len(test.extensions) > 0 {
		var serverInfo bytes.Buffer
		for _, ext := range test.extensions {
			pem.Encode(&serverInfo, &pem.Block{
				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
				Bytes: ext,
			})
		}
		serverInfoPath := tempFile(serverInfo.String())
		defer os.Remove(serverInfoPath)
		command = append(command, "-serverinfo", serverInfoPath)
	}

	// Renegotiation and KeyUpdate tests synchronize on the state messages
	// that OpenSSL only prints with -state (see opensslOutputSink), so
	// fail loudly if the flag was forgotten.
	if test.numRenegotiations > 0 || test.sendKeyUpdate {
		found := false
		for _, flag := range command[1:] {
			if flag == "-state" {
				found = true
				break
			}
		}

		if !found {
			panic("-state flag missing to OpenSSL, you need this if testing renegotiation or KeyUpdate")
		}
	}

	cmd := exec.Command(command[0], command[1:]...)
	stdin = opensslInput(make(chan opensslInputEvent))
	cmd.Stdin = stdin
	out := newOpensslOutputSink()
	cmd.Stdout = out
	cmd.Stderr = out
	if err := cmd.Start(); err != nil {
		return nil, nil, nil, nil, err
	}

	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
	// opening the listening socket, so we can't use that to wait until it
	// has started listening. Thus we are forced to poll until we get a
	// connection.
	// Up to 5 attempts with exponential backoff (5ms, 10ms, ... 80ms).
	var tcpConn net.Conn
	for i := uint(0); i < 5; i++ {
		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
			IP:   net.IPv4(127, 0, 0, 1),
			Port: serverPort,
		})
		if err == nil {
			break
		}
		time.Sleep((1 << i) * 5 * time.Millisecond)
	}
	if err != nil {
		// Closing stdin lets the child's stdin copier finish, so that
		// Wait can return after the process is killed.
		close(stdin)
		cmd.Process.Kill()
		err = fmt.Errorf("error connecting to the OpenSSL server: %v (%v)\n\n%s", err, cmd.Wait(), out)
		return nil, nil, nil, nil, err
	}

	record := &recordingConn{
		Conn: tcpConn,
	}

	return record, cmd, stdin, out, nil
}
288
289func (test *clientTest) dataPath() string {
290	return filepath.Join("testdata", "Client-"+test.name)
291}
292
293func (test *clientTest) loadData() (flows [][]byte, err error) {
294	in, err := os.Open(test.dataPath())
295	if err != nil {
296		return nil, err
297	}
298	defer in.Close()
299	return parseTestData(in)
300}
301
// run executes the client side of the test.
//
// When write is true the client talks to a freshly started OpenSSL reference
// server (see connFromCommand); the connection is recorded and the flows are
// saved to test.dataPath(). When write is false the recorded flows are
// replayed over a connection pair from localPipe, and the bytes sent by the
// client must match the recording exactly.
func (test *clientTest) run(t *testing.T, write bool) {
	checkOpenSSLVersion(t)

	var clientConn, serverConn net.Conn
	var recordingConn *recordingConn
	var childProcess *exec.Cmd
	var stdin opensslInput
	var stdout *opensslOutputSink

	if write {
		var err error
		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
		if err != nil {
			t.Fatalf("Failed to start subcommand: %s", err)
		}
		clientConn = recordingConn
		// Dump OpenSSL's output to help diagnose failures.
		defer func() {
			if t.Failed() {
				t.Logf("OpenSSL output:\n\n%s", stdout.all)
			}
		}()
	} else {
		clientConn, serverConn = localPipe(t)
	}

	// The TLS client runs in its own goroutine, which closes doneChan when it
	// finishes. Closing clientConn on the way out unblocks that goroutine if
	// this function exits early (e.g. through t.Fatalf).
	doneChan := make(chan bool)
	defer func() {
		clientConn.Close()
		<-doneChan
	}()
	go func() {
		defer close(doneChan)

		config := test.config
		if config == nil {
			config = testConfig
		}
		client := Client(clientConn, config)
		defer client.Close()

		if _, err := client.Write([]byte("hello\n")); err != nil {
			t.Errorf("Client.Write failed: %s", err)
			return
		}

		for i := 1; i <= test.numRenegotiations; i++ {
			// The initial handshake will generate a
			// handshakeComplete signal which needs to be quashed.
			if i == 1 && write {
				<-stdout.handshakeComplete
			}

			// OpenSSL will try to interleave application data and
			// a renegotiation if we send both concurrently.
			// Therefore: ask OpensSSL to start a renegotiation, run
			// a goroutine to call client.Read and thus process the
			// renegotiation request, watch for OpenSSL's stdout to
			// indicate that the handshake is complete and,
			// finally, have OpenSSL write something to cause
			// client.Read to complete.
			if write {
				stdin <- opensslRenegotiate
			}

			signalChan := make(chan struct{})

			go func() {
				defer close(signalChan)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				// Let the test map expected renegotiation errors to nil.
				if test.checkRenegotiationError != nil {
					newErr := test.checkRenegotiationError(i, err)
					if err != nil && newErr == nil {
						return
					}
					err = newErr
				}

				if err != nil {
					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}

				// Initial handshake plus i renegotiations.
				if expected := i + 1; client.handshakes != expected {
					t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
				}
			}()

			// A failing renegotiation never completes, so don't wait for
			// it or ask for the sentinel in that case.
			if write && test.renegotiationExpectedToFail != i {
				<-stdout.handshakeComplete
				stdin <- opensslSendSentinel
			}
			<-signalChan
		}

		if test.sendKeyUpdate {
			if write {
				<-stdout.handshakeComplete
				stdin <- opensslKeyUpdate
			}

			doneRead := make(chan struct{})

			go func() {
				defer close(doneRead)

				buf := make([]byte, 256)
				n, err := client.Read(buf)

				if err != nil {
					t.Errorf("Client.Read failed after KeyUpdate: %s", err)
					return
				}

				buf = buf[:n]
				if !bytes.Equal([]byte(opensslSentinel), buf) {
					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
				}
			}()

			if write {
				// There's no real reason to wait for the client KeyUpdate to
				// send data with the new server keys, except that s_server
				// drops writes if they are sent at the wrong time.
				<-stdout.readKeyUpdate
				stdin <- opensslSendSentinel
			}
			<-doneRead

			// Exercise the client's updated write keys as well.
			if _, err := client.Write([]byte("hello again\n")); err != nil {
				t.Errorf("Client.Write failed: %s", err)
				return
			}
		}

		if test.validate != nil {
			if err := test.validate(client.ConnectionState()); err != nil {
				t.Errorf("validate callback returned error: %s", err)
			}
		}

		// If the server sent us an alert after our last flight, give it a
		// chance to arrive.
		if write && test.renegotiationExpectedToFail == 0 {
			if err := peekError(client); err != nil {
				t.Errorf("final Read returned an error: %s", err)
			}
		}
	}()

	// Replay mode: flows alternate client→server (even index) and
	// server→client (odd index). Client flows are read and compared with the
	// recording; server flows are written to the client.
	if !write {
		flows, err := test.loadData()
		if err != nil {
			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
		}
		for i, b := range flows {
			if i%2 == 1 {
				serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
				serverConn.Write(b)
				continue
			}
			bb := make([]byte, len(b))
			serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
			_, err := io.ReadFull(serverConn, bb)
			if err != nil {
				t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
			}
			if !bytes.Equal(b, bb) {
				t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
			}
		}
	}

	<-doneChan
	if !write {
		serverConn.Close()
	}

	// Record mode: stop the reference server and save the captured flows.
	// Fewer than three flows means the handshake never got going.
	if write {
		path := test.dataPath()
		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
		if err != nil {
			t.Fatalf("Failed to create output file: %s", err)
		}
		defer out.Close()
		recordingConn.Close()
		close(stdin)
		childProcess.Process.Kill()
		childProcess.Wait()
		if len(recordingConn.flows) < 3 {
			t.Fatalf("Client connection didn't work")
		}
		recordingConn.WriteTo(out)
		t.Logf("Wrote %s\n", path)
	}
}
505
// peekError performs a single one-byte read bounded by a 100ms deadline to
// find out whether the next read would fail, for example because an alert is
// waiting on the wire. A read timeout means nothing is pending and yields
// nil; actually receiving data is reported as an error, and any other read
// error is returned as-is. The read deadline is left set on return.
func peekError(conn net.Conn) error {
	deadline := time.Now().Add(100 * time.Millisecond)
	conn.SetReadDeadline(deadline)

	var probe [1]byte
	n, err := conn.Read(probe[:])
	if n != 0 {
		return errors.New("unexpectedly read data")
	}
	if err == nil {
		return nil
	}
	if netErr, isNetErr := err.(net.Error); isNetErr && netErr.Timeout() {
		return nil
	}
	return err
}
519
520func runClientTestForVersion(t *testing.T, template *clientTest, version, option string) {
521	t.Run(version, func(t *testing.T) {
522		// Make a deep copy of the template before going parallel.
523		test := *template
524		if template.config != nil {
525			test.config = template.config.Clone()
526		}
527
528		if !*update {
529			t.Parallel()
530		}
531
532		test.name = version + "-" + test.name
533		test.args = append([]string{option}, test.args...)
534		test.run(t, *update)
535	})
536}
537
538func runClientTestTLS10(t *testing.T, template *clientTest) {
539	runClientTestForVersion(t, template, "TLSv10", "-tls1")
540}
541
542func runClientTestTLS11(t *testing.T, template *clientTest) {
543	runClientTestForVersion(t, template, "TLSv11", "-tls1_1")
544}
545
546func runClientTestTLS12(t *testing.T, template *clientTest) {
547	runClientTestForVersion(t, template, "TLSv12", "-tls1_2")
548}
549
550func runClientTestTLS13(t *testing.T, template *clientTest) {
551	runClientTestForVersion(t, template, "TLSv13", "-tls1_3")
552}
553
554func TestHandshakeClientRSARC4(t *testing.T) {
555	test := &clientTest{
556		name: "RSA-RC4",
557		args: []string{"-cipher", "RC4-SHA"},
558	}
559	runClientTestTLS10(t, test)
560	runClientTestTLS11(t, test)
561	runClientTestTLS12(t, test)
562}
563
564func TestHandshakeClientRSAAES128GCM(t *testing.T) {
565	test := &clientTest{
566		name: "AES128-GCM-SHA256",
567		args: []string{"-cipher", "AES128-GCM-SHA256"},
568	}
569	runClientTestTLS12(t, test)
570}
571
572func TestHandshakeClientRSAAES256GCM(t *testing.T) {
573	test := &clientTest{
574		name: "AES256-GCM-SHA384",
575		args: []string{"-cipher", "AES256-GCM-SHA384"},
576	}
577	runClientTestTLS12(t, test)
578}
579
580func TestHandshakeClientECDHERSAAES(t *testing.T) {
581	test := &clientTest{
582		name: "ECDHE-RSA-AES",
583		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA"},
584	}
585	runClientTestTLS10(t, test)
586	runClientTestTLS11(t, test)
587	runClientTestTLS12(t, test)
588}
589
590func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
591	test := &clientTest{
592		name: "ECDHE-ECDSA-AES",
593		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA"},
594		cert: testECDSACertificate,
595		key:  testECDSAPrivateKey,
596	}
597	runClientTestTLS10(t, test)
598	runClientTestTLS11(t, test)
599	runClientTestTLS12(t, test)
600}
601
602func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
603	test := &clientTest{
604		name: "ECDHE-ECDSA-AES-GCM",
605		args: []string{"-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
606		cert: testECDSACertificate,
607		key:  testECDSAPrivateKey,
608	}
609	runClientTestTLS12(t, test)
610}
611
612func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
613	test := &clientTest{
614		name: "ECDHE-ECDSA-AES256-GCM-SHA384",
615		args: []string{"-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
616		cert: testECDSACertificate,
617		key:  testECDSAPrivateKey,
618	}
619	runClientTestTLS12(t, test)
620}
621
622func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
623	test := &clientTest{
624		name: "AES128-SHA256",
625		args: []string{"-cipher", "AES128-SHA256"},
626	}
627	runClientTestTLS12(t, test)
628}
629
630func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
631	test := &clientTest{
632		name: "ECDHE-RSA-AES128-SHA256",
633		args: []string{"-cipher", "ECDHE-RSA-AES128-SHA256"},
634	}
635	runClientTestTLS12(t, test)
636}
637
638func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
639	test := &clientTest{
640		name: "ECDHE-ECDSA-AES128-SHA256",
641		args: []string{"-cipher", "ECDHE-ECDSA-AES128-SHA256"},
642		cert: testECDSACertificate,
643		key:  testECDSAPrivateKey,
644	}
645	runClientTestTLS12(t, test)
646}
647
648func TestHandshakeClientX25519(t *testing.T) {
649	config := testConfig.Clone()
650	config.CurvePreferences = []CurveID{X25519}
651
652	test := &clientTest{
653		name:   "X25519-ECDHE",
654		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "X25519"},
655		config: config,
656	}
657
658	runClientTestTLS12(t, test)
659	runClientTestTLS13(t, test)
660}
661
662func TestHandshakeClientP256(t *testing.T) {
663	config := testConfig.Clone()
664	config.CurvePreferences = []CurveID{CurveP256}
665
666	test := &clientTest{
667		name:   "P256-ECDHE",
668		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
669		config: config,
670	}
671
672	runClientTestTLS12(t, test)
673	runClientTestTLS13(t, test)
674}
675
676func TestHandshakeClientHelloRetryRequest(t *testing.T) {
677	config := testConfig.Clone()
678	config.CurvePreferences = []CurveID{X25519, CurveP256}
679
680	test := &clientTest{
681		name:   "HelloRetryRequest",
682		args:   []string{"-cipher", "ECDHE-RSA-AES128-GCM-SHA256", "-curves", "P-256"},
683		config: config,
684	}
685
686	runClientTestTLS13(t, test)
687}
688
689func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
690	config := testConfig.Clone()
691	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
692
693	test := &clientTest{
694		name:   "ECDHE-RSA-CHACHA20-POLY1305",
695		args:   []string{"-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
696		config: config,
697	}
698
699	runClientTestTLS12(t, test)
700}
701
702func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
703	config := testConfig.Clone()
704	config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
705
706	test := &clientTest{
707		name:   "ECDHE-ECDSA-CHACHA20-POLY1305",
708		args:   []string{"-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
709		config: config,
710		cert:   testECDSACertificate,
711		key:    testECDSAPrivateKey,
712	}
713
714	runClientTestTLS12(t, test)
715}
716
717func TestHandshakeClientAES128SHA256(t *testing.T) {
718	test := &clientTest{
719		name: "AES128-SHA256",
720		args: []string{"-ciphersuites", "TLS_AES_128_GCM_SHA256"},
721	}
722	runClientTestTLS13(t, test)
723}
724func TestHandshakeClientAES256SHA384(t *testing.T) {
725	test := &clientTest{
726		name: "AES256-SHA384",
727		args: []string{"-ciphersuites", "TLS_AES_256_GCM_SHA384"},
728	}
729	runClientTestTLS13(t, test)
730}
731func TestHandshakeClientCHACHA20SHA256(t *testing.T) {
732	test := &clientTest{
733		name: "CHACHA20-SHA256",
734		args: []string{"-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"},
735	}
736	runClientTestTLS13(t, test)
737}
738
739func TestHandshakeClientECDSATLS13(t *testing.T) {
740	test := &clientTest{
741		name: "ECDSA",
742		cert: testECDSACertificate,
743		key:  testECDSAPrivateKey,
744	}
745	runClientTestTLS13(t, test)
746}
747
748func TestHandshakeClientCertRSA(t *testing.T) {
749	config := testConfig.Clone()
750	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
751	config.Certificates = []Certificate{cert}
752
753	test := &clientTest{
754		name:   "ClientCert-RSA-RSA",
755		args:   []string{"-cipher", "AES128", "-Verify", "1"},
756		config: config,
757	}
758
759	runClientTestTLS10(t, test)
760	runClientTestTLS12(t, test)
761
762	test = &clientTest{
763		name:   "ClientCert-RSA-ECDSA",
764		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
765		config: config,
766		cert:   testECDSACertificate,
767		key:    testECDSAPrivateKey,
768	}
769
770	runClientTestTLS10(t, test)
771	runClientTestTLS12(t, test)
772	runClientTestTLS13(t, test)
773
774	test = &clientTest{
775		name:   "ClientCert-RSA-AES256-GCM-SHA384",
776		args:   []string{"-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-Verify", "1"},
777		config: config,
778		cert:   testRSACertificate,
779		key:    testRSAPrivateKey,
780	}
781
782	runClientTestTLS12(t, test)
783}
784
785func TestHandshakeClientCertECDSA(t *testing.T) {
786	config := testConfig.Clone()
787	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
788	config.Certificates = []Certificate{cert}
789
790	test := &clientTest{
791		name:   "ClientCert-ECDSA-RSA",
792		args:   []string{"-cipher", "AES128", "-Verify", "1"},
793		config: config,
794	}
795
796	runClientTestTLS10(t, test)
797	runClientTestTLS12(t, test)
798	runClientTestTLS13(t, test)
799
800	test = &clientTest{
801		name:   "ClientCert-ECDSA-ECDSA",
802		args:   []string{"-cipher", "ECDHE-ECDSA-AES128-SHA", "-Verify", "1"},
803		config: config,
804		cert:   testECDSACertificate,
805		key:    testECDSAPrivateKey,
806	}
807
808	runClientTestTLS10(t, test)
809	runClientTestTLS12(t, test)
810}
811
812// TestHandshakeClientCertRSAPSS tests a few separate things:
813//  * that our client can serve a PSS-signed certificate
814//  * that our client can validate a PSS-signed certificate
815//  * that our client can use rsa_pss_rsae_sha256 in its CertificateVerify
816//  * that our client can accpet rsa_pss_rsae_sha256 in the server CertificateVerify
817func TestHandshakeClientCertRSAPSS(t *testing.T) {
818	issuer, err := x509.ParseCertificate(testRSAPSSCertificate)
819	if err != nil {
820		panic(err)
821	}
822	rootCAs := x509.NewCertPool()
823	rootCAs.AddCert(issuer)
824
825	config := testConfig.Clone()
826	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
827	config.Certificates = []Certificate{cert}
828	config.RootCAs = rootCAs
829
830	test := &clientTest{
831		name: "ClientCert-RSA-RSAPSS",
832		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
833			"rsa_pss_rsae_sha256", "-sigalgs", "rsa_pss_rsae_sha256"},
834		config: config,
835		cert:   testRSAPSSCertificate,
836		key:    testRSAPrivateKey,
837	}
838
839	runClientTestTLS12(t, test)
840	runClientTestTLS13(t, test)
841}
842
843func TestHandshakeClientCertRSAPKCS1v15(t *testing.T) {
844	config := testConfig.Clone()
845	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
846	config.Certificates = []Certificate{cert}
847
848	test := &clientTest{
849		name: "ClientCert-RSA-RSAPKCS1v15",
850		args: []string{"-cipher", "AES128", "-Verify", "1", "-client_sigalgs",
851			"rsa_pkcs1_sha256", "-sigalgs", "rsa_pkcs1_sha256"},
852		config: config,
853	}
854
855	runClientTestTLS12(t, test)
856}
857
858func TestClientKeyUpdate(t *testing.T) {
859	test := &clientTest{
860		name:          "KeyUpdate",
861		args:          []string{"-state"},
862		sendKeyUpdate: true,
863	}
864	runClientTestTLS13(t, test)
865}
866
867func TestResumption(t *testing.T) {
868	t.Run("TLSv12", func(t *testing.T) { testResumption(t, VersionTLS12) })
869	t.Run("TLSv13", func(t *testing.T) { testResumption(t, VersionTLS13) })
870}
871
872func testResumption(t *testing.T, version uint16) {
873	serverConfig := &Config{
874		MaxVersion:   version,
875		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
876		Certificates: testConfig.Certificates,
877	}
878
879	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
880	if err != nil {
881		panic(err)
882	}
883
884	rootCAs := x509.NewCertPool()
885	rootCAs.AddCert(issuer)
886
887	clientConfig := &Config{
888		MaxVersion:         version,
889		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
890		ClientSessionCache: NewLRUClientSessionCache(32),
891		RootCAs:            rootCAs,
892		ServerName:         "example.golang",
893	}
894
895	testResumeState := func(test string, didResume bool) {
896		_, hs, err := testHandshake(t, clientConfig, serverConfig)
897		if err != nil {
898			t.Fatalf("%s: handshake failed: %s", test, err)
899		}
900		if hs.DidResume != didResume {
901			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
902		}
903		if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
904			t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
905		}
906	}
907
908	getTicket := func() []byte {
909		return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
910	}
911	deleteTicket := func() {
912		ticketKey := clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).sessionKey
913		clientConfig.ClientSessionCache.Put(ticketKey, nil)
914	}
915	corruptTicket := func() {
916		clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.masterSecret[0] ^= 0xff
917	}
918	randomKey := func() [32]byte {
919		var k [32]byte
920		if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
921			t.Fatalf("Failed to read new SessionTicketKey: %s", err)
922		}
923		return k
924	}
925
926	testResumeState("Handshake", false)
927	ticket := getTicket()
928	testResumeState("Resume", true)
929	if !bytes.Equal(ticket, getTicket()) && version != VersionTLS13 {
930		t.Fatal("first ticket doesn't match ticket after resumption")
931	}
932	if bytes.Equal(ticket, getTicket()) && version == VersionTLS13 {
933		t.Fatal("ticket didn't change after resumption")
934	}
935
936	key1 := randomKey()
937	serverConfig.SetSessionTicketKeys([][32]byte{key1})
938
939	testResumeState("InvalidSessionTicketKey", false)
940	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
941
942	key2 := randomKey()
943	serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
944	ticket = getTicket()
945	testResumeState("KeyChange", true)
946	if bytes.Equal(ticket, getTicket()) {
947		t.Fatal("new ticket wasn't included while resuming")
948	}
949	testResumeState("KeyChangeFinish", true)
950
951	// Reset serverConfig to ensure that calling SetSessionTicketKeys
952	// before the serverConfig is used works.
953	serverConfig = &Config{
954		MaxVersion:   version,
955		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
956		Certificates: testConfig.Certificates,
957	}
958	serverConfig.SetSessionTicketKeys([][32]byte{key2})
959
960	testResumeState("FreshConfig", true)
961
962	// In TLS 1.3, cross-cipher suite resumption is allowed as long as the KDF
963	// hash matches. Also, Config.CipherSuites does not apply to TLS 1.3.
964	if version != VersionTLS13 {
965		clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
966		testResumeState("DifferentCipherSuite", false)
967		testResumeState("DifferentCipherSuiteRecovers", true)
968	}
969
970	deleteTicket()
971	testResumeState("WithoutSessionTicket", false)
972
973	// Session resumption should work when using client certificates
974	deleteTicket()
975	serverConfig.ClientCAs = rootCAs
976	serverConfig.ClientAuth = RequireAndVerifyClientCert
977	clientConfig.Certificates = serverConfig.Certificates
978	testResumeState("InitialHandshake", false)
979	testResumeState("WithClientCertificates", true)
980	serverConfig.ClientAuth = NoClientCert
981
982	// Tickets should be removed from the session cache on TLS handshake
983	// failure, and the client should recover from a corrupted PSK
984	testResumeState("FetchTicketToCorrupt", false)
985	corruptTicket()
986	_, _, err = testHandshake(t, clientConfig, serverConfig)
987	if err == nil {
988		t.Fatalf("handshake did not fail with a corrupted client secret")
989	}
990	testResumeState("AfterHandshakeFailure", false)
991
992	clientConfig.ClientSessionCache = nil
993	testResumeState("WithoutSessionCache", false)
994}
995
// TestLRUClientSessionCache exercises the cache returned by
// NewLRUClientSessionCache. It checks lookups of added keys, eviction of the
// least recently used entries once capacity is exceeded, that Get refreshes
// recency, in-place updates, and that Put with a nil session deletes the key.
func TestLRUClientSessionCache(t *testing.T) {
	// Initialize cache of capacity 4.
	cache := NewLRUClientSessionCache(4)
	cs := make([]ClientSessionState, 6)
	keys := []string{"0", "1", "2", "3", "4", "5"}

	// Add 4 entries to the cache and look them up.
	for i := 0; i < 4; i++ {
		cache.Put(keys[i], &cs[i])
	}
	for i := 0; i < 4; i++ {
		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
		}
	}

	// Add 2 more entries to the cache. First 2 should be evicted.
	for i := 4; i < 6; i++ {
		cache.Put(keys[i], &cs[i])
	}
	for i := 0; i < 2; i++ {
		if s, ok := cache.Get(keys[i]); ok || s != nil {
			t.Fatalf("session cache should have evicted key: %s", keys[i])
		}
	}

	// Touch entry 2. LRU should evict 3 next.
	cache.Get(keys[2])
	cache.Put(keys[0], &cs[0])
	if s, ok := cache.Get(keys[3]); ok || s != nil {
		t.Fatalf("session cache should have evicted key 3")
	}

	// Update entry 0 in place.
	cache.Put(keys[0], &cs[3])
	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
		t.Fatalf("session cache failed update for key 0")
	}

	// Calling Put with a nil entry deletes the key.
	cache.Put(keys[0], nil)
	if _, ok := cache.Get(keys[0]); ok {
		t.Fatalf("session cache failed to delete key 0")
	}

	// Delete entry 2. LRU should keep 4 and 5.
	cache.Put(keys[2], nil)
	if _, ok := cache.Get(keys[2]); ok {
		t.Fatalf("session cache failed to delete key 2")
	}
	for i := 4; i < 6; i++ {
		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
			t.Fatalf("session cache should not have deleted key: %s", keys[i])
		}
	}
}
1052
// TestKeyLogTLS12 runs a TLS 1.2 handshake with KeyLogWriter set on both
// sides. It checks that each side logged exactly one line of the form
// "CLIENT_RANDOM <hex client random> <hex master secret>\n".
func TestKeyLogTLS12(t *testing.T) {
	var clientLog, serverLog bytes.Buffer

	cliCfg := testConfig.Clone()
	cliCfg.MaxVersion = VersionTLS12
	cliCfg.KeyLogWriter = &clientLog

	srvCfg := testConfig.Clone()
	srvCfg.MaxVersion = VersionTLS12
	srvCfg.KeyLogWriter = &serverLog

	cConn, sConn := localPipe(t)
	serverDone := make(chan bool)

	go func() {
		defer close(serverDone)

		err := Server(sConn, srvCfg).Handshake()
		if err != nil {
			t.Errorf("server: %s", err)
			return
		}
		sConn.Close()
	}()

	err := Client(cConn, cliCfg).Handshake()
	if err != nil {
		t.Fatalf("client: %s", err)
	}
	cConn.Close()
	<-serverDone

	// Layout: label, space, 32-byte client random as hex, space,
	// 48-byte master secret as hex, trailing newline.
	const wantLen = len("CLIENT_RANDOM") + 1 + 2*32 + 1 + 2*48 + 1
	// The client random is expected to be all zeros. This presumably
	// comes from the deterministic Rand in testConfig; confirm there.
	wantPrefix := "CLIENT_RANDOM " + strings.Repeat("0", 64) + " "

	verify := func(side, line string) {
		switch {
		case len(line) == 0:
			t.Fatalf("%s: no keylog line was produced", side)
		case len(line) != wantLen:
			t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, wantLen, len(line), line)
		case !strings.HasPrefix(line, wantPrefix):
			t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, line)
		}
	}

	verify("client", clientLog.String())
	verify("server", serverLog.String())
}
1105
// TestKeyLogTLS13 runs a handshake with KeyLogWriter set on both sides and no
// MaxVersion cap, which the test name implies negotiates TLS 1.3. It then
// checks that each side wrote exactly four key log lines.
func TestKeyLogTLS13(t *testing.T) {
	var clientLog, serverLog bytes.Buffer

	cliCfg := testConfig.Clone()
	cliCfg.KeyLogWriter = &clientLog

	srvCfg := testConfig.Clone()
	srvCfg.KeyLogWriter = &serverLog

	cConn, sConn := localPipe(t)
	serverDone := make(chan bool)

	go func() {
		defer close(serverDone)

		err := Server(sConn, srvCfg).Handshake()
		if err != nil {
			t.Errorf("server: %s", err)
			return
		}
		sConn.Close()
	}()

	err := Client(cConn, cliCfg).Handshake()
	if err != nil {
		t.Fatalf("client: %s", err)
	}
	cConn.Close()
	<-serverDone

	for _, side := range []struct {
		name string
		log  string
	}{
		{"client", clientLog.String()},
		{"server", serverLog.String()},
	} {
		lines := strings.Split(strings.TrimSpace(side.log), "\n")
		if n := len(lines); n != 4 {
			t.Errorf("Expected the %s to log 4 lines, got %d", side.name, n)
		}
	}
}
1146
// TestHandshakeClientALPNMatch runs the recorded "ALPN" client test over TLS
// 1.2 and TLS 1.3. The client lists proto2 before proto1, and the OpenSSL
// server is started with "-alpn proto1,proto2". The negotiated protocol must
// be proto1.
func TestHandshakeClientALPNMatch(t *testing.T) {
	cfg := testConfig.Clone()
	cfg.NextProtos = []string{"proto2", "proto1"}

	// The -alpn flag first appeared in OpenSSL 1.0.2, so regenerating the
	// recordings needs at least that version.
	test := &clientTest{
		name:   "ALPN",
		args:   []string{"-alpn", "proto1,proto2"},
		config: cfg,
		validate: func(state ConnectionState) error {
			// The server's preferences should override the client.
			if got := state.NegotiatedProtocol; got != "proto1" {
				return fmt.Errorf("Got protocol %q, wanted proto1", got)
			}
			return nil
		},
	}

	runClientTestTLS12(t, test)
	runClientTestTLS13(t, test)
}
1168
// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`.
// It is a base64-encoded serverinfo blob for TLS extension 18
// (signed_certificate_timestamp). TestHandshakClientSCTs decodes it and passes
// it to the test server as a ServerHello extension.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
1171
// TestHandshakClientSCTs checks that SignedCertificateTimestamps delivered in
// a TLS 1.2 ServerHello extension are exposed, one entry per SCT, in
// ConnectionState.SignedCertificateTimestamps.
func TestHandshakClientSCTs(t *testing.T) {
	cfg := testConfig.Clone()

	raw, err := base64.StdEncoding.DecodeString(sctsBase64)
	if err != nil {
		t.Fatal(err)
	}

	// The -serverinfo flag first appeared in OpenSSL 1.0.2, so
	// regenerating the recording needs at least that version.
	test := &clientTest{
		name:       "SCT",
		config:     cfg,
		extensions: [][]byte{raw},
		validate: func(state ConnectionState) error {
			// Byte ranges of the three individual SCTs within the
			// decoded serverinfo blob.
			want := [][]byte{raw[8:125], raw[127:245], raw[247:]}
			got := state.SignedCertificateTimestamps

			if len(got) != len(want) {
				return fmt.Errorf("Got %d scts, wanted %d", len(got), len(want))
			}
			for i := range want {
				if !bytes.Equal(got[i], want[i]) {
					return fmt.Errorf("SCT #%d contained %x, expected %x", i, got[i], want[i])
				}
			}
			return nil
		},
	}
	runClientTestTLS12(t, test)

	// No TLS 1.3 run: TLS 1.3 carries SCTs in Certificate extensions,
	// while -serverinfo can only inject ServerHello extensions.
}
1208
// TestRenegotiationRejected leaves Renegotiation as configured in testConfig.
// When OpenSSL requests a renegotiation, the client must fail with an error
// mentioning "no renegotiation".
func TestRenegotiationRejected(t *testing.T) {
	test := &clientTest{
		name:                        "RenegotiationRejected",
		args:                        []string{"-state"},
		config:                      testConfig.Clone(),
		numRenegotiations:           1,
		renegotiationExpectedToFail: 1,
		checkRenegotiationError: func(renegotiationNum int, err error) error {
			switch {
			case err == nil:
				return errors.New("expected error from renegotiation but got nil")
			case !strings.Contains(err.Error(), "no renegotiation"):
				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
			default:
				return nil
			}
		},
	}
	runClientTestTLS12(t, test)
}
1229
// TestRenegotiateOnce runs the recorded "RenegotiateOnce" test. A client
// configured with RenegotiateOnceAsClient goes through one server-requested
// renegotiation.
func TestRenegotiateOnce(t *testing.T) {
	cfg := testConfig.Clone()
	cfg.Renegotiation = RenegotiateOnceAsClient

	runClientTestTLS12(t, &clientTest{
		name:              "RenegotiateOnce",
		args:              []string{"-state"},
		config:            cfg,
		numRenegotiations: 1,
	})
}
1243
// TestRenegotiateTwice runs the recorded "RenegotiateTwice" test. A client
// configured with RenegotiateFreelyAsClient goes through two server-requested
// renegotiations.
func TestRenegotiateTwice(t *testing.T) {
	cfg := testConfig.Clone()
	cfg.Renegotiation = RenegotiateFreelyAsClient

	runClientTestTLS12(t, &clientTest{
		name:              "RenegotiateTwice",
		args:              []string{"-state"},
		config:            cfg,
		numRenegotiations: 2,
	})
}
1257
// TestRenegotiateTwiceRejected configures the client with
// RenegotiateOnceAsClient and has OpenSSL request two renegotiations. The
// first must succeed; the second must fail with a "no renegotiation" error.
func TestRenegotiateTwiceRejected(t *testing.T) {
	cfg := testConfig.Clone()
	cfg.Renegotiation = RenegotiateOnceAsClient

	checkErr := func(renegotiationNum int, err error) error {
		// The first renegotiation is permitted; report whatever it returned.
		if renegotiationNum == 1 {
			return err
		}

		switch {
		case err == nil:
			return errors.New("expected error from renegotiation but got nil")
		case !strings.Contains(err.Error(), "no renegotiation"):
			return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
		default:
			return nil
		}
	}

	runClientTestTLS12(t, &clientTest{
		name:                        "RenegotiateTwiceRejected",
		args:                        []string{"-state"},
		config:                      cfg,
		numRenegotiations:           2,
		renegotiationExpectedToFail: 2,
		checkRenegotiationError:     checkErr,
	})
}
1285
1286func TestHandshakeClientExportKeyingMaterial(t *testing.T) {
1287	test := &clientTest{
1288		name:   "ExportKeyingMaterial",
1289		config: testConfig.Clone(),
1290		validate: func(state ConnectionState) error {
1291			if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
1292				return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
1293			} else if len(km) != 42 {
1294				return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
1295			}
1296			return nil
1297		},
1298	}
1299	runClientTestTLS10(t, test)
1300	runClientTestTLS12(t, test)
1301	runClientTestTLS13(t, test)
1302}
1303
// hostnameInSNITests pairs a Config.ServerName (in) with the server_name the
// client is expected to put in its ClientHello (out). An empty out means no
// name is expected. IP literals, in any form, must never be sent as SNI.
var hostnameInSNITests = []struct {
	in, out string
}{
	// Opaque strings are passed through unchanged.
	{in: "", out: ""},
	{in: "localhost", out: "localhost"},
	{in: "foo, bar, baz and qux", out: "foo, bar, baz and qux"},

	// DNS hostnames lose a trailing root dot.
	{in: "golang.org", out: "golang.org"},
	{in: "golang.org.", out: "golang.org"},

	// Literal IPv4 address.
	{in: "1.2.3.4", out: ""},

	// Literal IPv6 addresses: bare, with a zone identifier, and in the
	// bracketed "[addr]" form (RFC 5952), which is also accepted.
	{in: "::1", out: ""},
	{in: "::1%lo0", out: ""},
	{in: "[::1]", out: ""},
	{in: "[::1%lo0]", out: ""},
}
1325
// TestHostnameInSNI starts a client handshake for each hostnameInSNITests
// entry. It captures the first record the client sends, parses it as a
// ClientHello, and checks which server_name it carries.
func TestHostnameInSNI(t *testing.T) {
	for _, tc := range hostnameInSNITests {
		clientConn, serverConn := localPipe(t)

		go func(name string) {
			// Only the ClientHello matters. The handshake presumably
			// fails once both ends are closed below; its error is ignored.
			Client(clientConn, &Config{ServerName: name, InsecureSkipVerify: true}).Handshake()
		}(tc.in)

		// A record is a 5-byte header whose last two bytes hold the
		// big-endian body length, followed by that many body bytes.
		var hdr [5]byte
		if _, err := io.ReadFull(serverConn, hdr[:]); err != nil {
			t.Fatal(err)
		}
		bodyLen := int(hdr[3])<<8 | int(hdr[4])

		body := make([]byte, bodyLen)
		if _, err := io.ReadFull(serverConn, body); err != nil {
			t.Fatal(err)
		}

		clientConn.Close()
		serverConn.Close()

		var hello clientHelloMsg
		if !hello.unmarshal(body) {
			t.Errorf("unmarshaling ClientHello for %q failed", tc.in)
			continue
		}
		if tc.in != tc.out && hello.serverName == tc.in {
			t.Errorf("prohibited %q found in ClientHello: %x", tc.in, body)
		}
		if hello.serverName != tc.out {
			t.Errorf("expected %q not found in ClientHello: %x", tc.out, body)
		}
	}
}
1361
// TestServerSelectingUnconfiguredCipherSuite checks that the client rejects a
// ServerHello selecting a cipher suite that the client didn't offer.
// See #13174.
func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
	c, s := localPipe(t)
	errChan := make(chan error, 1)

	go func() {
		client := Client(c, &Config{
			ServerName:   "foo",
			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
		})
		errChan <- client.Handshake()
	}()

	// Consume the client's first record (its ClientHello). The 5-byte
	// header carries the big-endian body length in bytes 3 and 4.
	var header [5]byte
	if _, err := io.ReadFull(s, header[:]); err != nil {
		t.Fatal(err)
	}
	recordLen := int(header[3])<<8 | int(header[4])

	record := make([]byte, recordLen)
	if _, err := io.ReadFull(s, record); err != nil {
		t.Fatal(err)
	}

	// Create a ServerHello that selects a different cipher suite than the
	// sole one that the client offered.
	serverHello := &serverHelloMsg{
		vers:        VersionTLS12,
		random:      make([]byte, 32),
		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
	}
	serverHelloBytes := serverHello.marshal()

	s.Write([]byte{
		byte(recordTypeHandshake),
		byte(VersionTLS12 >> 8),
		byte(VersionTLS12 & 0xff),
		byte(len(serverHelloBytes) >> 8),
		byte(len(serverHelloBytes)),
	})
	s.Write(serverHelloBytes)
	s.Close()

	// Guard against a nil error. Calling err.Error() on nil would panic
	// and hide the real failure: the client accepting the bogus suite.
	err := <-errChan
	if err == nil {
		t.Fatal("Expected error about unconfigured cipher suite but the handshake returned no error")
	}
	if !strings.Contains(err.Error(), "unconfigured cipher") {
		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
	}
}
1411
// TestVerifyPeerCertificate runs testVerifyPeerCertificate as one named
// subtest per protocol version.
func TestVerifyPeerCertificate(t *testing.T) {
	for _, v := range []struct {
		name    string
		version uint16
	}{
		{"TLSv12", VersionTLS12},
		{"TLSv13", VersionTLS13},
	} {
		v := v
		t.Run(v.name, func(t *testing.T) { testVerifyPeerCertificate(t, v.version) })
	}
}
1416
// testVerifyPeerCertificate exercises Config.VerifyPeerCertificate on both the
// client and the server, with MaxVersion set to version on both sides. It
// checks that:
//   - the callback receives the raw chain and the verified chains;
//   - an error returned by the callback is returned from Handshake;
//   - with InsecureSkipVerify the callback still runs but receives no verified
//     chains.
func testVerifyPeerCertificate(t *testing.T, version uint16) {
	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		t.Fatal(err)
	}

	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(issuer)

	// Fixed clock, presumably chosen so the test certificates are within
	// their validity period.
	now := func() time.Time { return time.Unix(1476984729, 0) }

	sentinelErr := errors.New("TestVerifyPeerCertificate")

	// verifyCallback expects one raw certificate and at least one verified
	// chain, then records that it ran.
	verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
		if l := len(rawCerts); l != 1 {
			return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
		}
		if len(validatedChains) == 0 {
			return errors.New("got len(validatedChains) = 0, wanted non-zero")
		}
		*called = true
		return nil
	}

	tests := []struct {
		configureServer func(*Config, *bool)
		configureClient func(*Config, *bool)
		validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
	}{
		// Both sides verify normally and both callbacks must run.
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return verifyCallback(called, rawCerts, validatedChains)
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return verifyCallback(called, rawCerts, validatedChains)
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != nil {
					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
				}
				if serverErr != nil {
					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
				}
				if !clientCalled {
					t.Errorf("test[%d]: client did not call callback", testNo)
				}
				if !serverCalled {
					t.Errorf("test[%d]: server did not call callback", testNo)
				}
			},
		},
		// The server callback's error must be returned from the
		// server's Handshake.
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return sentinelErr
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.VerifyPeerCertificate = nil
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if serverErr != sentinelErr {
					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
				}
			},
		},
		// The client callback's error must be returned from the
		// client's Handshake.
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
			},
			configureClient: func(config *Config, called *bool) {
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return sentinelErr
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != sentinelErr {
					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
				}
			},
		},
		// With InsecureSkipVerify the client callback still runs, but
		// with no verified chains.
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = true
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					if l := len(rawCerts); l != 1 {
						return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
					}
					// With InsecureSkipVerify set, this
					// callback should still be called but
					// validatedChains must be empty.
					if l := len(validatedChains); l != 0 {
						return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
					}
					*called = true
					return nil
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != nil {
					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
				}
				if serverErr != nil {
					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
				}
				if !clientCalled {
					t.Errorf("test[%d]: client did not call callback", testNo)
				}
			},
		},
	}

	for i, test := range tests {
		c, s := localPipe(t)
		done := make(chan error)

		var clientCalled, serverCalled bool

		go func() {
			config := testConfig.Clone()
			config.ServerName = "example.golang"
			config.ClientAuth = RequireAndVerifyClientCert
			config.ClientCAs = rootCAs
			config.Time = now
			config.MaxVersion = version
			test.configureServer(config, &serverCalled)

			// Keep the result in a goroutine-local variable rather than
			// writing the function-level err owned by the test goroutine.
			serverErr := Server(s, config).Handshake()
			s.Close()
			done <- serverErr
		}()

		config := testConfig.Clone()
		config.ServerName = "example.golang"
		config.RootCAs = rootCAs
		config.Time = now
		config.MaxVersion = version
		test.configureClient(config, &clientCalled)
		clientErr := Client(c, config).Handshake()
		c.Close()
		serverErr := <-done

		test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
	}
}
1572
1573// brokenConn wraps a net.Conn and causes all Writes after a certain number to
1574// fail with brokenConnErr.
1575type brokenConn struct {
1576	net.Conn
1577
1578	// breakAfter is the number of successful writes that will be allowed
1579	// before all subsequent writes fail.
1580	breakAfter int
1581
1582	// numWrites is the number of writes that have been done.
1583	numWrites int
1584}
1585
1586// brokenConnErr is the error that brokenConn returns once exhausted.
1587var brokenConnErr = errors.New("too many writes to brokenConn")
1588
1589func (b *brokenConn) Write(data []byte) (int, error) {
1590	if b.numWrites >= b.breakAfter {
1591		return 0, brokenConnErr
1592	}
1593
1594	b.numWrites++
1595	return b.Conn.Write(data)
1596}
1597
// TestFailedWrite checks that a Write error hit during the client handshake
// is returned unchanged from Handshake. It covers both a failing first write
// and a failing second write.
func TestFailedWrite(t *testing.T) {
	for breakAfter := 0; breakAfter < 2; breakAfter++ {
		clientConn, serverConn := localPipe(t)
		serverDone := make(chan bool)

		go func() {
			Server(serverConn, testConfig).Handshake()
			serverConn.Close()
			serverDone <- true
		}()

		broken := &brokenConn{Conn: clientConn, breakAfter: breakAfter}
		if err := Client(broken, testConfig).Handshake(); err != brokenConnErr {
			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
		}
		broken.Close()

		<-serverDone
	}
}
1620
1621// writeCountingConn wraps a net.Conn and counts the number of Write calls.
1622type writeCountingConn struct {
1623	net.Conn
1624
1625	// numWrites is the number of writes that have been done.
1626	numWrites int
1627}
1628
1629func (wcc *writeCountingConn) Write(data []byte) (int, error) {
1630	wcc.numWrites++
1631	return wcc.Conn.Write(data)
1632}
1633
// TestBuffering runs testBuffering as one named subtest per protocol version.
func TestBuffering(t *testing.T) {
	for _, v := range []struct {
		name    string
		version uint16
	}{
		{"TLSv12", VersionTLS12},
		{"TLSv13", VersionTLS13},
	} {
		v := v
		t.Run(v.name, func(t *testing.T) { testBuffering(t, v.version) })
	}
}
1638
1639func testBuffering(t *testing.T, version uint16) {
1640	c, s := localPipe(t)
1641	done := make(chan bool)
1642
1643	clientWCC := &writeCountingConn{Conn: c}
1644	serverWCC := &writeCountingConn{Conn: s}
1645
1646	go func() {
1647		config := testConfig.Clone()
1648		config.MaxVersion = version
1649		Server(serverWCC, config).Handshake()
1650		serverWCC.Close()
1651		done <- true
1652	}()
1653
1654	err := Client(clientWCC, testConfig).Handshake()
1655	if err != nil {
1656		t.Fatal(err)
1657	}
1658	clientWCC.Close()
1659	<-done
1660
1661	var expectedClient, expectedServer int
1662	if version == VersionTLS13 {
1663		expectedClient = 2
1664		expectedServer = 1
1665	} else {
1666		expectedClient = 2
1667		expectedServer = 2
1668	}
1669
1670	if n := clientWCC.numWrites; n != expectedClient {
1671		t.Errorf("expected client handshake to complete with %d writes, but saw %d", expectedClient, n)
1672	}
1673
1674	if n := serverWCC.numWrites; n != expectedServer {
1675		t.Errorf("expected server handshake to complete with %d writes, but saw %d", expectedServer, n)
1676	}
1677}
1678
// TestAlertFlushing makes the server fail while signing by giving it an RSA
// key with a bogus private exponent. It checks that the client receives the
// resulting internal_error alert and that the server made exactly one Write.
func TestAlertFlushing(t *testing.T) {
	cConn, sConn := localPipe(t)
	serverDone := make(chan bool)

	clientWCC := &writeCountingConn{Conn: cConn}
	serverWCC := &writeCountingConn{Conn: sConn}

	// Cause a signature-time error: keep the real public key but replace
	// the private exponent with a meaningless value.
	brokenKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
	brokenKey.D = big.NewInt(42)

	srvCfg := testConfig.Clone()
	srvCfg.Certificates = []Certificate{{
		Certificate: [][]byte{testRSACertificate},
		PrivateKey:  &brokenKey,
	}}

	go func() {
		Server(serverWCC, srvCfg).Handshake()
		serverWCC.Close()
		serverDone <- true
	}()

	err := Client(clientWCC, testConfig).Handshake()
	if err == nil {
		t.Fatal("client unexpectedly returned no error")
	}

	const expectedError = "remote error: tls: internal error"
	if msg := err.Error(); !strings.Contains(msg, expectedError) {
		t.Fatalf("expected to find %q in error but error was %q", expectedError, msg)
	}
	clientWCC.Close()
	<-serverDone

	if got := serverWCC.numWrites; got != 1 {
		t.Errorf("expected server handshake to complete with one write, but saw %d", got)
	}
}
1718
// TestHandshakeRace races a Read against a Write on a fresh client Conn, so
// either one may end up driving the handshake. This provides some evidence
// that the handshake locking has no races or deadlocks. Odd iterations start
// the Write first; even iterations start the Read first.
func TestHandshakeRace(t *testing.T) {
	t.Parallel()
	for i := 0; i < 32; i++ {
		c, s := localPipe(t)

		// Server: complete the handshake, then echo back one byte.
		go func() {
			server := Server(s, testConfig)
			if err := server.Handshake(); err != nil {
				panic(err)
			}

			var request [1]byte
			n, err := server.Read(request[:])
			if err != nil {
				panic(err)
			}
			if n != 1 {
				// err is nil here, so panic with a descriptive value
				// rather than panic(nil).
				panic(fmt.Sprintf("server read %d bytes, want 1", n))
			}

			server.Write(request[:])
			server.Close()
		}()

		startWrite := make(chan struct{})
		startRead := make(chan struct{})
		readDone := make(chan struct{})

		client := Client(c, testConfig)
		go func() {
			<-startWrite
			var request [1]byte
			client.Write(request[:])
		}()

		go func() {
			<-startRead
			var reply [1]byte
			if _, err := io.ReadFull(client, reply[:]); err != nil {
				panic(err)
			}
			c.Close()
			readDone <- struct{}{}
		}()

		if i&1 == 1 {
			startWrite <- struct{}{}
			startRead <- struct{}{}
		} else {
			startRead <- struct{}{}
			startWrite <- struct{}{}
		}
		<-readDone
	}
}
1773
1774var getClientCertificateTests = []struct {
1775	setup               func(*Config, *Config)
1776	expectedClientError string
1777	verify              func(*testing.T, int, *ConnectionState)
1778}{
1779	{
1780		func(clientConfig, serverConfig *Config) {
1781			// Returning a Certificate with no certificate data
1782			// should result in an empty message being sent to the
1783			// server.
1784			serverConfig.ClientCAs = nil
1785			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
1786				if len(cri.SignatureSchemes) == 0 {
1787					panic("empty SignatureSchemes")
1788				}
1789				if len(cri.AcceptableCAs) != 0 {
1790					panic("AcceptableCAs should have been empty")
1791				}
1792				return new(Certificate), nil
1793			}
1794		},
1795		"",
1796		func(t *testing.T, testNum int, cs *ConnectionState) {
1797			if l := len(cs.PeerCertificates); l != 0 {
1798				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
1799			}
1800		},
1801	},
1802	{
1803		func(clientConfig, serverConfig *Config) {
1804			// With TLS 1.1, the SignatureSchemes should be
1805			// synthesised from the supported certificate types.
1806			clientConfig.MaxVersion = VersionTLS11
1807			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
1808				if len(cri.SignatureSchemes) == 0 {
1809					panic("empty SignatureSchemes")
1810				}
1811				return new(Certificate), nil
1812			}
1813		},
1814		"",
1815		func(t *testing.T, testNum int, cs *ConnectionState) {
1816			if l := len(cs.PeerCertificates); l != 0 {
1817				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
1818			}
1819		},
1820	},
1821	{
1822		func(clientConfig, serverConfig *Config) {
1823			// Returning an error should abort the handshake with
1824			// that error.
1825			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
1826				return nil, errors.New("GetClientCertificate")
1827			}
1828		},
1829		"GetClientCertificate",
1830		func(t *testing.T, testNum int, cs *ConnectionState) {
1831		},
1832	},
1833	{
1834		func(clientConfig, serverConfig *Config) {
1835			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
1836				if len(cri.AcceptableCAs) == 0 {
1837					panic("empty AcceptableCAs")
1838				}
1839				cert := &Certificate{
1840					Certificate: [][]byte{testRSACertificate},
1841					PrivateKey:  testRSAPrivateKey,
1842				}
1843				return cert, nil
1844			}
1845		},
1846		"",
1847		func(t *testing.T, testNum int, cs *ConnectionState) {
1848			if len(cs.VerifiedChains) == 0 {
1849				t.Errorf("#%d: expected some verified chains, but found none", testNum)
1850			}
1851		},
1852	},
1853}
1854
// TestGetClientCertificate runs testGetClientCertificate as one named subtest
// per protocol version.
func TestGetClientCertificate(t *testing.T) {
	for _, v := range []struct {
		name    string
		version uint16
	}{
		{"TLSv12", VersionTLS12},
		{"TLSv13", VersionTLS13},
	} {
		v := v
		t.Run(v.name, func(t *testing.T) { testGetClientCertificate(t, v.version) })
	}
}
1859
// testGetClientCertificate runs each getClientCertificateTests case with
// MaxVersion set to version on both sides. The server uses
// VerifyClientCertIfGiven and trusts testRSACertificateIssuer. Each case's
// verify function receives the server's ConnectionState: it runs after a
// fully successful handshake, or when the client fails with exactly the
// expected error (in which case the state may be the zero value).
func testGetClientCertificate(t *testing.T, version uint16) {
	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		t.Fatal(err)
	}

	for i, test := range getClientCertificateTests {
		serverConfig := testConfig.Clone()
		serverConfig.ClientAuth = VerifyClientCertIfGiven
		serverConfig.RootCAs = x509.NewCertPool()
		serverConfig.RootCAs.AddCert(issuer)
		serverConfig.ClientCAs = serverConfig.RootCAs
		serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) }
		serverConfig.MaxVersion = version

		clientConfig := testConfig.Clone()
		clientConfig.MaxVersion = version

		test.setup(clientConfig, serverConfig)

		type serverResult struct {
			cs  ConnectionState
			err error
		}

		c, s := localPipe(t)
		done := make(chan serverResult)

		go func() {
			defer s.Close()
			server := Server(s, serverConfig)
			err := server.Handshake()

			var cs ConnectionState
			if err == nil {
				cs = server.ConnectionState()
			}
			done <- serverResult{cs, err}
		}()

		clientErr := Client(c, clientConfig).Handshake()
		c.Close()

		result := <-done

		if clientErr != nil {
			if len(test.expectedClientError) == 0 {
				t.Errorf("#%d: client error: %v", i, clientErr)
			} else if got := clientErr.Error(); got != test.expectedClientError {
				t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
			} else {
				test.verify(t, i, &result.cs)
			}
		} else if len(test.expectedClientError) > 0 {
			t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
		} else if err := result.err; err != nil {
			t.Errorf("#%d: server error: %v", i, err)
		} else {
			test.verify(t, i, &result.cs)
		}
	}
}
1922
1923func TestRSAPSSKeyError(t *testing.T) {
1924	// crypto/tls does not support the rsa_pss_pss_* SignatureSchemes. If support for
1925	// public keys with OID RSASSA-PSS is added to crypto/x509, they will be misused with
1926	// the rsa_pss_rsae_* SignatureSchemes. Assert that RSASSA-PSS certificates don't
1927	// parse, or that they don't carry *rsa.PublicKey keys.
1928	b, _ := pem.Decode([]byte(`
1929-----BEGIN CERTIFICATE-----
1930MIIDZTCCAhygAwIBAgIUCF2x0FyTgZG0CC9QTDjGWkB5vgEwPgYJKoZIhvcNAQEK
1931MDGgDTALBglghkgBZQMEAgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQC
1932AgDeMBIxEDAOBgNVBAMMB1JTQS1QU1MwHhcNMTgwNjI3MjI0NDM2WhcNMTgwNzI3
1933MjI0NDM2WjASMRAwDgYDVQQDDAdSU0EtUFNTMIIBIDALBgkqhkiG9w0BAQoDggEP
1934ADCCAQoCggEBANxDm0f76JdI06YzsjB3AmmjIYkwUEGxePlafmIASFjDZl/elD0Z
1935/a7xLX468b0qGxLS5al7XCcEprSdsDR6DF5L520+pCbpfLyPOjuOvGmk9KzVX4x5
1936b05YXYuXdsQ0Kjxcx2i3jjCday6scIhMJVgBZxTEyMj1thPQM14SHzKCd/m6HmCL
1937QmswpH2yMAAcBRWzRpp/vdH5DeOJEB3aelq7094no731mrLUCHRiZ1htq8BDB3ou
1938czwqgwspbqZ4dnMXl2MvfySQ5wJUxQwILbiuAKO2lVVPUbFXHE9pgtznNoPvKwQT
1939JNcX8ee8WIZc2SEGzofjk3NpjR+2ADB2u3sCAwEAAaNTMFEwHQYDVR0OBBYEFNEz
1940AdyJ2f+fU+vSCS6QzohnOnprMB8GA1UdIwQYMBaAFNEzAdyJ2f+fU+vSCS6Qzohn
1941OnprMA8GA1UdEwEB/wQFMAMBAf8wPgYJKoZIhvcNAQEKMDGgDTALBglghkgBZQME
1942AgGhGjAYBgkqhkiG9w0BAQgwCwYJYIZIAWUDBAIBogQCAgDeA4IBAQCjEdrR5aab
1943sZmCwrMeKidXgfkmWvfuLDE+TCbaqDZp7BMWcMQXT9O0UoUT5kqgKj2ARm2pEW0Z
1944H3Z1vj3bbds72qcDIJXp+l0fekyLGeCrX/CbgnMZXEP7+/+P416p34ChR1Wz4dU1
1945KD3gdsUuTKKeMUog3plxlxQDhRQmiL25ygH1LmjLd6dtIt0GVRGr8lj3euVeprqZ
1946bZ3Uq5eLfsn8oPgfC57gpO6yiN+UURRTlK3bgYvLh4VWB3XXk9UaQZ7Mq1tpXjoD
1947HYFybkWzibkZp4WRo+Fa28rirH+/wHt0vfeN7UCceURZEx4JaxIIfe4ku7uDRhJi
1948RwBA9Xk1KBNF
1949-----END CERTIFICATE-----`))
1950	if b == nil {
1951		t.Fatal("Failed to decode certificate")
1952	}
1953	cert, err := x509.ParseCertificate(b.Bytes)
1954	if err != nil {
1955		return
1956	}
1957	if _, ok := cert.PublicKey.(*rsa.PublicKey); ok {
1958		t.Error("A RSASSA-PSS certificate was parsed like a PKCS#1 v1.5 one, and it will be mistakenly used with rsa_pss_rsae_* signature algorithms")
1959	}
1960}
1961
// TestCloseClientConnectionOnIdleServer closes a client Conn from another
// goroutine while its handshake waits on a server that never replies. It
// checks that Handshake returns an error and that the error is not a timeout.
func TestCloseClientConnectionOnIdleServer(t *testing.T) {
	clientConn, serverConn := localPipe(t)
	client := Client(clientConn, testConfig.Clone())

	go func() {
		// Wait until the client has written something, then close it
		// underneath the in-progress handshake.
		var first [1]byte
		serverConn.Read(first[:])
		client.Close()
	}()
	client.SetWriteDeadline(time.Now().Add(time.Minute))

	err := client.Handshake()
	if err == nil {
		t.Errorf("Error expected, but no error returned")
		return
	}
	if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
		t.Errorf("Expected a closed network connection error but got '%s'", netErr.Error())
	}
}
1980