1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"crypto/ecdsa"
10	"crypto/rsa"
11	"crypto/x509"
12	"encoding/base64"
13	"encoding/binary"
14	"encoding/pem"
15	"errors"
16	"fmt"
17	"io"
18	"math/big"
19	"net"
20	"os"
21	"os/exec"
22	"path/filepath"
23	"strconv"
24	"strings"
25	"sync"
26	"testing"
27	"time"
28)
29
30// Note: see comment in handshake_test.go for details of how the reference
31// tests work.
32
// opensslInputEvent enumerates possible inputs that can be sent to the stdin
// of the reference `openssl s_server` process.
type opensslInputEvent int

const (
	// opensslRenegotiate causes OpenSSL to request a renegotiation of the
	// connection.
	opensslRenegotiate opensslInputEvent = iota

	// opensslSendSentinel causes OpenSSL to send the contents of
	// opensslSentinel on the connection.
	opensslSendSentinel
)

// opensslSentinel is the text OpenSSL is asked to send by opensslSendSentinel.
const opensslSentinel = "SENTINEL\n"

// opensslInput is an io.Reader, used as the child's stdin, that is fed by
// sending opensslInputEvent values on the channel. Closing the channel makes
// Read return io.EOF.
type opensslInput chan opensslInputEvent

// Read blocks until a single event arrives and writes the text for it into
// buf: "R\n" for opensslRenegotiate and opensslSentinel for
// opensslSendSentinel. It returns io.EOF once the channel is closed and
// panics on an unknown event.
//
// NOTE(review): if buf is shorter than the event's text the excess is dropped;
// os/exec reads with large buffers, so this is presumably never hit.
func (i opensslInput) Read(buf []byte) (n int, err error) {
	event, ok := <-i
	if !ok {
		return 0, io.EOF
	}

	switch event {
	case opensslRenegotiate:
		return copy(buf, "R\n"), nil
	case opensslSendSentinel:
		return copy(buf, opensslSentinel), nil
	default:
		panic("unknown event")
	}
}
65
// opensslOutputSink is an io.Writer that receives the stdout and stderr from
// an `openssl` process and sends a value to handshakeComplete when it sees a
// log message from a completed server handshake.
//
// handshakeComplete is unbuffered, so Write blocks on every matching line
// until someone receives from the channel. all accumulates every byte ever
// written; line holds the trailing output not yet terminated by a newline.
type opensslOutputSink struct {
	handshakeComplete chan struct{}
	all               []byte
	line              []byte
}

// newOpensslOutputSink returns an empty sink with an unbuffered
// handshakeComplete channel.
func newOpensslOutputSink() *opensslOutputSink {
	return &opensslOutputSink{handshakeComplete: make(chan struct{})}
}

// opensslEndOfHandshake is a message that the “openssl s_server” tool will
// print when a handshake completes if run with “-state”.
const opensslEndOfHandshake = "SSL_accept:SSLv3/TLS write finished"

// Write records data and, for every complete line equal to
// opensslEndOfHandshake, sends one value on handshakeComplete. It always
// reports that all of data was consumed.
func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
	o.line = append(o.line, data...)
	o.all = append(o.all, data...)

	for {
		i := bytes.IndexByte(o.line, '\n')
		if i < 0 {
			break
		}

		// Comparing via string(...) == const avoids allocating a []byte
		// copy of the constant for every line.
		if string(o.line[:i]) == opensslEndOfHandshake {
			o.handshakeComplete <- struct{}{}
		}
		o.line = o.line[i+1:]
	}

	return len(data), nil
}

// WriteTo writes everything the sink has received so far to w.
func (o *opensslOutputSink) WriteTo(w io.Writer) (int64, error) {
	n, err := w.Write(o.all)
	return int64(n), err
}
106
// clientTest represents a test of the TLS client handshake against a reference
// implementation.
//
// When recording (run with write set) the reference server is an OpenSSL
// process started by connFromCommand; otherwise the flows previously stored
// at dataPath are replayed over an in-memory pipe.
type clientTest struct {
	// name is a freeform string identifying the test and the file in which
	// the expected results will be stored (see dataPath).
	name string
	// command, if not empty, contains a series of arguments for the
	// command to run for the reference server. connFromCommand falls back
	// to defaultServerCommand when it is empty.
	command []string
	// config, if not nil, contains a custom Config to use for this test;
	// otherwise run uses testConfig.
	config *Config
	// cert, if not empty, contains a DER-encoded certificate for the
	// reference server; otherwise testRSACertificate is used.
	cert []byte
	// key, if not nil, contains either a *rsa.PrivateKey or
	// *ecdsa.PrivateKey which is the private key for the reference server.
	// Otherwise testRSAPrivateKey is used; any other type makes
	// connFromCommand panic.
	key interface{}
	// extensions, if not nil, contains a list of extension data to be returned
	// from the ServerHello. The data should be in standard TLS format with
	// a 2-byte uint16 type, 2-byte data length, followed by the extension data.
	// Each entry is written as a SERVERINFO PEM block into a file passed to
	// OpenSSL with -serverinfo.
	extensions [][]byte
	// validate, if not nil, is a function that will be called with the
	// ConnectionState of the resulting connection after all renegotiations
	// have run. It returns a non-nil error if the ConnectionState is
	// unacceptable; run reports that error with t.Errorf.
	validate func(ConnectionState) error
	// numRenegotiations is the number of times that the connection will be
	// renegotiated. A value above zero requires "-state" in the OpenSSL
	// command line, since run waits for OpenSSL's end-of-handshake log line.
	numRenegotiations int
	// renegotiationExpectedToFail, if not zero, is the number of the
	// renegotiation attempt that is expected to fail (counting from 1). For
	// that attempt run does not wait for OpenSSL to report a completed
	// handshake and does not ask it to send the sentinel.
	renegotiationExpectedToFail int
	// checkRenegotiationError, if not nil, is called with any error
	// arising from renegotiation. It can map expected errors to nil to
	// ignore them. It is called with the renegotiation number and the
	// (possibly nil) error from the client's Read; any non-nil result is
	// reported as a test failure.
	checkRenegotiationError func(renegotiationNum int, err error) error
}
143
// defaultServerCommand is the reference server that connFromCommand starts
// when a clientTest does not provide its own command.
var defaultServerCommand = []string{
	"openssl",
	"s_server",
}
145
146// connFromCommand starts the reference server process, connects to it and
147// returns a recordingConn for the connection. The stdin return value is an
148// opensslInput for the stdin of the child process. It must be closed before
149// Waiting for child.
150func (test *clientTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, stdin opensslInput, stdout *opensslOutputSink, err error) {
151	cert := testRSACertificate
152	if len(test.cert) > 0 {
153		cert = test.cert
154	}
155	certPath := tempFile(string(cert))
156	defer os.Remove(certPath)
157
158	var key interface{} = testRSAPrivateKey
159	if test.key != nil {
160		key = test.key
161	}
162	var pemType string
163	var derBytes []byte
164	switch key := key.(type) {
165	case *rsa.PrivateKey:
166		pemType = "RSA"
167		derBytes = x509.MarshalPKCS1PrivateKey(key)
168	case *ecdsa.PrivateKey:
169		pemType = "EC"
170		var err error
171		derBytes, err = x509.MarshalECPrivateKey(key)
172		if err != nil {
173			panic(err)
174		}
175	default:
176		panic("unknown key type")
177	}
178
179	var pemOut bytes.Buffer
180	pem.Encode(&pemOut, &pem.Block{Type: pemType + " PRIVATE KEY", Bytes: derBytes})
181
182	keyPath := tempFile(string(pemOut.Bytes()))
183	defer os.Remove(keyPath)
184
185	var command []string
186	if len(test.command) > 0 {
187		command = append(command, test.command...)
188	} else {
189		command = append(command, defaultServerCommand...)
190	}
191	command = append(command, "-cert", certPath, "-certform", "DER", "-key", keyPath)
192	// serverPort contains the port that OpenSSL will listen on. OpenSSL
193	// can't take "0" as an argument here so we have to pick a number and
194	// hope that it's not in use on the machine. Since this only occurs
195	// when -update is given and thus when there's a human watching the
196	// test, this isn't too bad.
197	const serverPort = 24323
198	command = append(command, "-accept", strconv.Itoa(serverPort))
199
200	if len(test.extensions) > 0 {
201		var serverInfo bytes.Buffer
202		for _, ext := range test.extensions {
203			pem.Encode(&serverInfo, &pem.Block{
204				Type:  fmt.Sprintf("SERVERINFO FOR EXTENSION %d", binary.BigEndian.Uint16(ext)),
205				Bytes: ext,
206			})
207		}
208		serverInfoPath := tempFile(serverInfo.String())
209		defer os.Remove(serverInfoPath)
210		command = append(command, "-serverinfo", serverInfoPath)
211	}
212
213	if test.numRenegotiations > 0 {
214		found := false
215		for _, flag := range command[1:] {
216			if flag == "-state" {
217				found = true
218				break
219			}
220		}
221
222		if !found {
223			panic("-state flag missing to OpenSSL. You need this if testing renegotiation")
224		}
225	}
226
227	cmd := exec.Command(command[0], command[1:]...)
228	stdin = opensslInput(make(chan opensslInputEvent))
229	cmd.Stdin = stdin
230	out := newOpensslOutputSink()
231	cmd.Stdout = out
232	cmd.Stderr = out
233	if err := cmd.Start(); err != nil {
234		return nil, nil, nil, nil, err
235	}
236
237	// OpenSSL does print an "ACCEPT" banner, but it does so *before*
238	// opening the listening socket, so we can't use that to wait until it
239	// has started listening. Thus we are forced to poll until we get a
240	// connection.
241	var tcpConn net.Conn
242	for i := uint(0); i < 5; i++ {
243		tcpConn, err = net.DialTCP("tcp", nil, &net.TCPAddr{
244			IP:   net.IPv4(127, 0, 0, 1),
245			Port: serverPort,
246		})
247		if err == nil {
248			break
249		}
250		time.Sleep((1 << i) * 5 * time.Millisecond)
251	}
252	if err != nil {
253		close(stdin)
254		out.WriteTo(os.Stdout)
255		cmd.Process.Kill()
256		return nil, nil, nil, nil, cmd.Wait()
257	}
258
259	record := &recordingConn{
260		Conn: tcpConn,
261	}
262
263	return record, cmd, stdin, out, nil
264}
265
266func (test *clientTest) dataPath() string {
267	return filepath.Join("testdata", "Client-"+test.name)
268}
269
270func (test *clientTest) loadData() (flows [][]byte, err error) {
271	in, err := os.Open(test.dataPath())
272	if err != nil {
273		return nil, err
274	}
275	defer in.Close()
276	return parseTestData(in)
277}
278
279func (test *clientTest) run(t *testing.T, write bool) {
280	checkOpenSSLVersion(t)
281
282	var clientConn, serverConn net.Conn
283	var recordingConn *recordingConn
284	var childProcess *exec.Cmd
285	var stdin opensslInput
286	var stdout *opensslOutputSink
287
288	if write {
289		var err error
290		recordingConn, childProcess, stdin, stdout, err = test.connFromCommand()
291		if err != nil {
292			t.Fatalf("Failed to start subcommand: %s", err)
293		}
294		clientConn = recordingConn
295	} else {
296		clientConn, serverConn = net.Pipe()
297	}
298
299	config := test.config
300	if config == nil {
301		config = testConfig
302	}
303	client := Client(clientConn, config)
304
305	doneChan := make(chan bool)
306	go func() {
307		defer func() { doneChan <- true }()
308		defer clientConn.Close()
309		defer client.Close()
310
311		if _, err := client.Write([]byte("hello\n")); err != nil {
312			t.Errorf("Client.Write failed: %s", err)
313			return
314		}
315
316		for i := 1; i <= test.numRenegotiations; i++ {
317			// The initial handshake will generate a
318			// handshakeComplete signal which needs to be quashed.
319			if i == 1 && write {
320				<-stdout.handshakeComplete
321			}
322
323			// OpenSSL will try to interleave application data and
324			// a renegotiation if we send both concurrently.
325			// Therefore: ask OpensSSL to start a renegotiation, run
326			// a goroutine to call client.Read and thus process the
327			// renegotiation request, watch for OpenSSL's stdout to
328			// indicate that the handshake is complete and,
329			// finally, have OpenSSL write something to cause
330			// client.Read to complete.
331			if write {
332				stdin <- opensslRenegotiate
333			}
334
335			signalChan := make(chan struct{})
336
337			go func() {
338				defer func() { signalChan <- struct{}{} }()
339
340				buf := make([]byte, 256)
341				n, err := client.Read(buf)
342
343				if test.checkRenegotiationError != nil {
344					newErr := test.checkRenegotiationError(i, err)
345					if err != nil && newErr == nil {
346						return
347					}
348					err = newErr
349				}
350
351				if err != nil {
352					t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
353					return
354				}
355
356				buf = buf[:n]
357				if !bytes.Equal([]byte(opensslSentinel), buf) {
358					t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
359				}
360
361				if expected := i + 1; client.handshakes != expected {
362					t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
363				}
364			}()
365
366			if write && test.renegotiationExpectedToFail != i {
367				<-stdout.handshakeComplete
368				stdin <- opensslSendSentinel
369			}
370			<-signalChan
371		}
372
373		if test.validate != nil {
374			if err := test.validate(client.ConnectionState()); err != nil {
375				t.Errorf("validate callback returned error: %s", err)
376			}
377		}
378	}()
379
380	if !write {
381		flows, err := test.loadData()
382		if err != nil {
383			t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
384		}
385		for i, b := range flows {
386			if i%2 == 1 {
387				serverConn.Write(b)
388				continue
389			}
390			bb := make([]byte, len(b))
391			_, err := io.ReadFull(serverConn, bb)
392			if err != nil {
393				t.Fatalf("%s #%d: %s", test.name, i, err)
394			}
395			if !bytes.Equal(b, bb) {
396				t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i, bb, b)
397			}
398		}
399		serverConn.Close()
400	}
401
402	<-doneChan
403
404	if write {
405		path := test.dataPath()
406		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
407		if err != nil {
408			t.Fatalf("Failed to create output file: %s", err)
409		}
410		defer out.Close()
411		recordingConn.Close()
412		close(stdin)
413		childProcess.Process.Kill()
414		childProcess.Wait()
415		if len(recordingConn.flows) < 3 {
416			os.Stdout.Write(childProcess.Stdout.(*opensslOutputSink).all)
417			t.Fatalf("Client connection didn't work")
418		}
419		recordingConn.WriteTo(out)
420		fmt.Printf("Wrote %s\n", path)
421	}
422}
423
var (
	// didParMu guards didPar.
	didParMu sync.Mutex
	// didPar remembers every *testing.T that setParallel has already
	// marked as parallel.
	didPar = map[*testing.T]bool{}
)

// setParallel marks t as parallel the first time it is called for t and does
// nothing on later calls, since t.Parallel panics if called twice.
func setParallel(t *testing.T) {
	didParMu.Lock()
	already := didPar[t]
	if !already {
		didPar[t] = true
	}
	didParMu.Unlock()

	if already {
		return
	}
	t.Parallel()
}
440
441func runClientTestForVersion(t *testing.T, template *clientTest, prefix, option string) {
442	setParallel(t)
443
444	test := *template
445	test.name = prefix + test.name
446	if len(test.command) == 0 {
447		test.command = defaultClientCommand
448	}
449	test.command = append([]string(nil), test.command...)
450	test.command = append(test.command, option)
451	test.run(t, *update)
452}
453
454func runClientTestTLS10(t *testing.T, template *clientTest) {
455	runClientTestForVersion(t, template, "TLSv10-", "-tls1")
456}
457
458func runClientTestTLS11(t *testing.T, template *clientTest) {
459	runClientTestForVersion(t, template, "TLSv11-", "-tls1_1")
460}
461
462func runClientTestTLS12(t *testing.T, template *clientTest) {
463	runClientTestForVersion(t, template, "TLSv12-", "-tls1_2")
464}
465
466func TestHandshakeClientRSARC4(t *testing.T) {
467	test := &clientTest{
468		name:    "RSA-RC4",
469		command: []string{"openssl", "s_server", "-cipher", "RC4-SHA"},
470	}
471	runClientTestTLS10(t, test)
472	runClientTestTLS11(t, test)
473	runClientTestTLS12(t, test)
474}
475
476func TestHandshakeClientRSAAES128GCM(t *testing.T) {
477	test := &clientTest{
478		name:    "AES128-GCM-SHA256",
479		command: []string{"openssl", "s_server", "-cipher", "AES128-GCM-SHA256"},
480	}
481	runClientTestTLS12(t, test)
482}
483
484func TestHandshakeClientRSAAES256GCM(t *testing.T) {
485	test := &clientTest{
486		name:    "AES256-GCM-SHA384",
487		command: []string{"openssl", "s_server", "-cipher", "AES256-GCM-SHA384"},
488	}
489	runClientTestTLS12(t, test)
490}
491
492func TestHandshakeClientECDHERSAAES(t *testing.T) {
493	test := &clientTest{
494		name:    "ECDHE-RSA-AES",
495		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA"},
496	}
497	runClientTestTLS10(t, test)
498	runClientTestTLS11(t, test)
499	runClientTestTLS12(t, test)
500}
501
502func TestHandshakeClientECDHEECDSAAES(t *testing.T) {
503	test := &clientTest{
504		name:    "ECDHE-ECDSA-AES",
505		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA"},
506		cert:    testECDSACertificate,
507		key:     testECDSAPrivateKey,
508	}
509	runClientTestTLS10(t, test)
510	runClientTestTLS11(t, test)
511	runClientTestTLS12(t, test)
512}
513
514func TestHandshakeClientECDHEECDSAAESGCM(t *testing.T) {
515	test := &clientTest{
516		name:    "ECDHE-ECDSA-AES-GCM",
517		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-GCM-SHA256"},
518		cert:    testECDSACertificate,
519		key:     testECDSAPrivateKey,
520	}
521	runClientTestTLS12(t, test)
522}
523
524func TestHandshakeClientAES256GCMSHA384(t *testing.T) {
525	test := &clientTest{
526		name:    "ECDHE-ECDSA-AES256-GCM-SHA384",
527		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES256-GCM-SHA384"},
528		cert:    testECDSACertificate,
529		key:     testECDSAPrivateKey,
530	}
531	runClientTestTLS12(t, test)
532}
533
534func TestHandshakeClientAES128CBCSHA256(t *testing.T) {
535	test := &clientTest{
536		name:    "AES128-SHA256",
537		command: []string{"openssl", "s_server", "-cipher", "AES128-SHA256"},
538	}
539	runClientTestTLS12(t, test)
540}
541
542func TestHandshakeClientECDHERSAAES128CBCSHA256(t *testing.T) {
543	test := &clientTest{
544		name:    "ECDHE-RSA-AES128-SHA256",
545		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-SHA256"},
546	}
547	runClientTestTLS12(t, test)
548}
549
550func TestHandshakeClientECDHEECDSAAES128CBCSHA256(t *testing.T) {
551	test := &clientTest{
552		name:    "ECDHE-ECDSA-AES128-SHA256",
553		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA256"},
554		cert:    testECDSACertificate,
555		key:     testECDSAPrivateKey,
556	}
557	runClientTestTLS12(t, test)
558}
559
560func TestHandshakeClientX25519(t *testing.T) {
561	config := testConfig.Clone()
562	config.CurvePreferences = []CurveID{X25519}
563
564	test := &clientTest{
565		name:    "X25519-ECDHE-RSA-AES-GCM",
566		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES128-GCM-SHA256"},
567		config:  config,
568	}
569
570	runClientTestTLS12(t, test)
571}
572
573func TestHandshakeClientECDHERSAChaCha20(t *testing.T) {
574	config := testConfig.Clone()
575	config.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}
576
577	test := &clientTest{
578		name:    "ECDHE-RSA-CHACHA20-POLY1305",
579		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-CHACHA20-POLY1305"},
580		config:  config,
581	}
582
583	runClientTestTLS12(t, test)
584}
585
586func TestHandshakeClientECDHEECDSAChaCha20(t *testing.T) {
587	config := testConfig.Clone()
588	config.CipherSuites = []uint16{TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305}
589
590	test := &clientTest{
591		name:    "ECDHE-ECDSA-CHACHA20-POLY1305",
592		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-CHACHA20-POLY1305"},
593		config:  config,
594		cert:    testECDSACertificate,
595		key:     testECDSAPrivateKey,
596	}
597
598	runClientTestTLS12(t, test)
599}
600
601func TestHandshakeClientCertRSA(t *testing.T) {
602	config := testConfig.Clone()
603	cert, _ := X509KeyPair([]byte(clientCertificatePEM), []byte(clientKeyPEM))
604	config.Certificates = []Certificate{cert}
605
606	test := &clientTest{
607		name:    "ClientCert-RSA-RSA",
608		command: []string{"openssl", "s_server", "-cipher", "AES128", "-verify", "1"},
609		config:  config,
610	}
611
612	runClientTestTLS10(t, test)
613	runClientTestTLS12(t, test)
614
615	test = &clientTest{
616		name:    "ClientCert-RSA-ECDSA",
617		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
618		config:  config,
619		cert:    testECDSACertificate,
620		key:     testECDSAPrivateKey,
621	}
622
623	runClientTestTLS10(t, test)
624	runClientTestTLS12(t, test)
625
626	test = &clientTest{
627		name:    "ClientCert-RSA-AES256-GCM-SHA384",
628		command: []string{"openssl", "s_server", "-cipher", "ECDHE-RSA-AES256-GCM-SHA384", "-verify", "1"},
629		config:  config,
630		cert:    testRSACertificate,
631		key:     testRSAPrivateKey,
632	}
633
634	runClientTestTLS12(t, test)
635}
636
637func TestHandshakeClientCertECDSA(t *testing.T) {
638	config := testConfig.Clone()
639	cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
640	config.Certificates = []Certificate{cert}
641
642	test := &clientTest{
643		name:    "ClientCert-ECDSA-RSA",
644		command: []string{"openssl", "s_server", "-cipher", "AES128", "-verify", "1"},
645		config:  config,
646	}
647
648	runClientTestTLS10(t, test)
649	runClientTestTLS12(t, test)
650
651	test = &clientTest{
652		name:    "ClientCert-ECDSA-ECDSA",
653		command: []string{"openssl", "s_server", "-cipher", "ECDHE-ECDSA-AES128-SHA", "-verify", "1"},
654		config:  config,
655		cert:    testECDSACertificate,
656		key:     testECDSAPrivateKey,
657	}
658
659	runClientTestTLS10(t, test)
660	runClientTestTLS12(t, test)
661}
662
663func TestClientResumption(t *testing.T) {
664	serverConfig := &Config{
665		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
666		Certificates: testConfig.Certificates,
667	}
668
669	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
670	if err != nil {
671		panic(err)
672	}
673
674	rootCAs := x509.NewCertPool()
675	rootCAs.AddCert(issuer)
676
677	clientConfig := &Config{
678		CipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
679		ClientSessionCache: NewLRUClientSessionCache(32),
680		RootCAs:            rootCAs,
681		ServerName:         "example.golang",
682	}
683
684	testResumeState := func(test string, didResume bool) {
685		_, hs, err := testHandshake(clientConfig, serverConfig)
686		if err != nil {
687			t.Fatalf("%s: handshake failed: %s", test, err)
688		}
689		if hs.DidResume != didResume {
690			t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
691		}
692		if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) {
693			t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains)
694		}
695	}
696
697	getTicket := func() []byte {
698		return clientConfig.ClientSessionCache.(*lruSessionCache).q.Front().Value.(*lruSessionCacheEntry).state.sessionTicket
699	}
700	randomKey := func() [32]byte {
701		var k [32]byte
702		if _, err := io.ReadFull(serverConfig.rand(), k[:]); err != nil {
703			t.Fatalf("Failed to read new SessionTicketKey: %s", err)
704		}
705		return k
706	}
707
708	testResumeState("Handshake", false)
709	ticket := getTicket()
710	testResumeState("Resume", true)
711	if !bytes.Equal(ticket, getTicket()) {
712		t.Fatal("first ticket doesn't match ticket after resumption")
713	}
714
715	key1 := randomKey()
716	serverConfig.SetSessionTicketKeys([][32]byte{key1})
717
718	testResumeState("InvalidSessionTicketKey", false)
719	testResumeState("ResumeAfterInvalidSessionTicketKey", true)
720
721	key2 := randomKey()
722	serverConfig.SetSessionTicketKeys([][32]byte{key2, key1})
723	ticket = getTicket()
724	testResumeState("KeyChange", true)
725	if bytes.Equal(ticket, getTicket()) {
726		t.Fatal("new ticket wasn't included while resuming")
727	}
728	testResumeState("KeyChangeFinish", true)
729
730	// Reset serverConfig to ensure that calling SetSessionTicketKeys
731	// before the serverConfig is used works.
732	serverConfig = &Config{
733		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
734		Certificates: testConfig.Certificates,
735	}
736	serverConfig.SetSessionTicketKeys([][32]byte{key2})
737
738	testResumeState("FreshConfig", true)
739
740	clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
741	testResumeState("DifferentCipherSuite", false)
742	testResumeState("DifferentCipherSuiteRecovers", true)
743
744	clientConfig.ClientSessionCache = nil
745	testResumeState("WithoutSessionCache", false)
746}
747
748func TestLRUClientSessionCache(t *testing.T) {
749	// Initialize cache of capacity 4.
750	cache := NewLRUClientSessionCache(4)
751	cs := make([]ClientSessionState, 6)
752	keys := []string{"0", "1", "2", "3", "4", "5", "6"}
753
754	// Add 4 entries to the cache and look them up.
755	for i := 0; i < 4; i++ {
756		cache.Put(keys[i], &cs[i])
757	}
758	for i := 0; i < 4; i++ {
759		if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
760			t.Fatalf("session cache failed lookup for added key: %s", keys[i])
761		}
762	}
763
764	// Add 2 more entries to the cache. First 2 should be evicted.
765	for i := 4; i < 6; i++ {
766		cache.Put(keys[i], &cs[i])
767	}
768	for i := 0; i < 2; i++ {
769		if s, ok := cache.Get(keys[i]); ok || s != nil {
770			t.Fatalf("session cache should have evicted key: %s", keys[i])
771		}
772	}
773
774	// Touch entry 2. LRU should evict 3 next.
775	cache.Get(keys[2])
776	cache.Put(keys[0], &cs[0])
777	if s, ok := cache.Get(keys[3]); ok || s != nil {
778		t.Fatalf("session cache should have evicted key 3")
779	}
780
781	// Update entry 0 in place.
782	cache.Put(keys[0], &cs[3])
783	if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
784		t.Fatalf("session cache failed update for key 0")
785	}
786
787	// Adding a nil entry is valid.
788	cache.Put(keys[0], nil)
789	if s, ok := cache.Get(keys[0]); !ok || s != nil {
790		t.Fatalf("failed to add nil entry to cache")
791	}
792}
793
794func TestKeyLog(t *testing.T) {
795	var serverBuf, clientBuf bytes.Buffer
796
797	clientConfig := testConfig.Clone()
798	clientConfig.KeyLogWriter = &clientBuf
799
800	serverConfig := testConfig.Clone()
801	serverConfig.KeyLogWriter = &serverBuf
802
803	c, s := net.Pipe()
804	done := make(chan bool)
805
806	go func() {
807		defer close(done)
808
809		if err := Server(s, serverConfig).Handshake(); err != nil {
810			t.Errorf("server: %s", err)
811			return
812		}
813		s.Close()
814	}()
815
816	if err := Client(c, clientConfig).Handshake(); err != nil {
817		t.Fatalf("client: %s", err)
818	}
819
820	c.Close()
821	<-done
822
823	checkKeylogLine := func(side, loggedLine string) {
824		if len(loggedLine) == 0 {
825			t.Fatalf("%s: no keylog line was produced", side)
826		}
827		const expectedLen = 13 /* "CLIENT_RANDOM" */ +
828			1 /* space */ +
829			32*2 /* hex client nonce */ +
830			1 /* space */ +
831			48*2 /* hex master secret */ +
832			1 /* new line */
833		if len(loggedLine) != expectedLen {
834			t.Fatalf("%s: keylog line has incorrect length (want %d, got %d): %q", side, expectedLen, len(loggedLine), loggedLine)
835		}
836		if !strings.HasPrefix(loggedLine, "CLIENT_RANDOM "+strings.Repeat("0", 64)+" ") {
837			t.Fatalf("%s: keylog line has incorrect structure or nonce: %q", side, loggedLine)
838		}
839	}
840
841	checkKeylogLine("client", string(clientBuf.Bytes()))
842	checkKeylogLine("server", string(serverBuf.Bytes()))
843}
844
845func TestHandshakeClientALPNMatch(t *testing.T) {
846	config := testConfig.Clone()
847	config.NextProtos = []string{"proto2", "proto1"}
848
849	test := &clientTest{
850		name: "ALPN",
851		// Note that this needs OpenSSL 1.0.2 because that is the first
852		// version that supports the -alpn flag.
853		command: []string{"openssl", "s_server", "-alpn", "proto1,proto2"},
854		config:  config,
855		validate: func(state ConnectionState) error {
856			// The server's preferences should override the client.
857			if state.NegotiatedProtocol != "proto1" {
858				return fmt.Errorf("Got protocol %q, wanted proto1", state.NegotiatedProtocol)
859			}
860			return nil
861		},
862	}
863	runClientTestTLS12(t, test)
864}
865
// sctsBase64 contains data from `openssl s_client -serverinfo 18 -connect ritter.vg:443`.
// Once decoded it is serverinfo data for TLS extension 18 in the format the
// clientTest.extensions field expects (2-byte type, 2-byte length, payload).
// TestHandshakClientSCTs expects it to yield three SCTs.
const sctsBase64 = "ABIBaQFnAHUApLkJkLQYWBSHuxOizGdwCjw1mAT5G9+443fNDsgN3BAAAAFHl5nuFgAABAMARjBEAiAcS4JdlW5nW9sElUv2zvQyPoZ6ejKrGGB03gjaBZFMLwIgc1Qbbn+hsH0RvObzhS+XZhr3iuQQJY8S9G85D9KeGPAAdgBo9pj4H2SCvjqM7rkoHUz8cVFdZ5PURNEKZ6y7T0/7xAAAAUeX4bVwAAAEAwBHMEUCIDIhFDgG2HIuADBkGuLobU5a4dlCHoJLliWJ1SYT05z6AiEAjxIoZFFPRNWMGGIjskOTMwXzQ1Wh2e7NxXE1kd1J0QsAdgDuS723dc5guuFCaR+r4Z5mow9+X7By2IMAxHuJeqj9ywAAAUhcZIqHAAAEAwBHMEUCICmJ1rBT09LpkbzxtUC+Hi7nXLR0J+2PmwLp+sJMuqK+AiEAr0NkUnEVKVhAkccIFpYDqHOlZaBsuEhWWrYpg2RtKp0="
868
869func TestHandshakClientSCTs(t *testing.T) {
870	config := testConfig.Clone()
871
872	scts, err := base64.StdEncoding.DecodeString(sctsBase64)
873	if err != nil {
874		t.Fatal(err)
875	}
876
877	test := &clientTest{
878		name: "SCT",
879		// Note that this needs OpenSSL 1.0.2 because that is the first
880		// version that supports the -serverinfo flag.
881		command:    []string{"openssl", "s_server"},
882		config:     config,
883		extensions: [][]byte{scts},
884		validate: func(state ConnectionState) error {
885			expectedSCTs := [][]byte{
886				scts[8:125],
887				scts[127:245],
888				scts[247:],
889			}
890			if n := len(state.SignedCertificateTimestamps); n != len(expectedSCTs) {
891				return fmt.Errorf("Got %d scts, wanted %d", n, len(expectedSCTs))
892			}
893			for i, expected := range expectedSCTs {
894				if sct := state.SignedCertificateTimestamps[i]; !bytes.Equal(sct, expected) {
895					return fmt.Errorf("SCT #%d contained %x, expected %x", i, sct, expected)
896				}
897			}
898			return nil
899		},
900	}
901	runClientTestTLS12(t, test)
902}
903
904func TestRenegotiationRejected(t *testing.T) {
905	config := testConfig.Clone()
906	test := &clientTest{
907		name:                        "RenegotiationRejected",
908		command:                     []string{"openssl", "s_server", "-state"},
909		config:                      config,
910		numRenegotiations:           1,
911		renegotiationExpectedToFail: 1,
912		checkRenegotiationError: func(renegotiationNum int, err error) error {
913			if err == nil {
914				return errors.New("expected error from renegotiation but got nil")
915			}
916			if !strings.Contains(err.Error(), "no renegotiation") {
917				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
918			}
919			return nil
920		},
921	}
922
923	runClientTestTLS12(t, test)
924}
925
926func TestRenegotiateOnce(t *testing.T) {
927	config := testConfig.Clone()
928	config.Renegotiation = RenegotiateOnceAsClient
929
930	test := &clientTest{
931		name:              "RenegotiateOnce",
932		command:           []string{"openssl", "s_server", "-state"},
933		config:            config,
934		numRenegotiations: 1,
935	}
936
937	runClientTestTLS12(t, test)
938}
939
940func TestRenegotiateTwice(t *testing.T) {
941	config := testConfig.Clone()
942	config.Renegotiation = RenegotiateFreelyAsClient
943
944	test := &clientTest{
945		name:              "RenegotiateTwice",
946		command:           []string{"openssl", "s_server", "-state"},
947		config:            config,
948		numRenegotiations: 2,
949	}
950
951	runClientTestTLS12(t, test)
952}
953
954func TestRenegotiateTwiceRejected(t *testing.T) {
955	config := testConfig.Clone()
956	config.Renegotiation = RenegotiateOnceAsClient
957
958	test := &clientTest{
959		name:                        "RenegotiateTwiceRejected",
960		command:                     []string{"openssl", "s_server", "-state"},
961		config:                      config,
962		numRenegotiations:           2,
963		renegotiationExpectedToFail: 2,
964		checkRenegotiationError: func(renegotiationNum int, err error) error {
965			if renegotiationNum == 1 {
966				return err
967			}
968
969			if err == nil {
970				return errors.New("expected error from renegotiation but got nil")
971			}
972			if !strings.Contains(err.Error(), "no renegotiation") {
973				return fmt.Errorf("expected renegotiation to be rejected but got %q", err)
974			}
975			return nil
976		},
977	}
978
979	runClientTestTLS12(t, test)
980}
981
// hostnameInSNITests pairs a Config.ServerName (in) with the server name the
// ClientHello is expected to carry (out). IP literals map to an empty name,
// and a trailing dot on a DNS name is stripped.
var hostnameInSNITests = []struct {
	in, out string
}{
	// Opaque string
	{in: "", out: ""},
	{in: "localhost", out: "localhost"},
	{in: "foo, bar, baz and qux", out: "foo, bar, baz and qux"},

	// DNS hostname
	{in: "golang.org", out: "golang.org"},
	{in: "golang.org.", out: "golang.org"},

	// Literal IPv4 address
	{in: "1.2.3.4", out: ""},

	// Literal IPv6 address
	{in: "::1", out: ""},
	{in: "::1%lo0", out: ""}, // with zone identifier
	{in: "[::1]", out: ""},   // as per RFC 5952 we allow the [] style as IPv6 literal
	{in: "[::1%lo0]", out: ""},
}
1003
// TestHostnameInSNI checks, for every entry of hostnameInSNITests, which
// server_name (if any) the client writes into its first ClientHello.
func TestHostnameInSNI(t *testing.T) {
	for _, tc := range hostnameInSNITests {
		clientConn, serverConn := net.Pipe()

		// Both pipe ends are closed as soon as the ClientHello has been read,
		// so this handshake always fails; its result is deliberately ignored.
		go func(serverName string) {
			cfg := &Config{ServerName: serverName, InsecureSkipVerify: true}
			Client(clientConn, cfg).Handshake()
		}(tc.in)

		// A record is a 5-byte header, whose bytes 3 and 4 hold the
		// big-endian body length, followed by that many body bytes.
		var hdr [5]byte
		if _, err := io.ReadFull(serverConn, hdr[:]); err != nil {
			t.Fatal(err)
		}
		body := make([]byte, int(hdr[3])<<8|int(hdr[4]))
		if _, err := io.ReadFull(serverConn, body); err != nil {
			t.Fatal(err)
		}

		clientConn.Close()
		serverConn.Close()

		var hello clientHelloMsg
		if ok := hello.unmarshal(body); !ok {
			t.Errorf("unmarshaling ClientHello for %q failed", tc.in)
			continue
		}
		if tc.in != tc.out && hello.serverName == tc.in {
			t.Errorf("prohibited %q found in ClientHello: %x", tc.in, body)
		}
		if hello.serverName != tc.out {
			t.Errorf("expected %q not found in ClientHello: %x", tc.out, body)
		}
	}
}
1039
// TestServerSelectingUnconfiguredCipherSuite checks that the client aborts the
// handshake when the ServerHello selects a cipher suite that the client did
// not offer. See #13174.
func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
	c, s := net.Pipe()
	// Close both ends on every exit path (including t.Fatal) so the client
	// goroutine is never left blocked on the pipe. Closing a net.Pipe end more
	// than once is harmless.
	defer c.Close()
	defer s.Close()
	errChan := make(chan error, 1)

	go func() {
		client := Client(c, &Config{
			ServerName:   "foo",
			CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256},
		})
		errChan <- client.Handshake()
	}()

	// Consume the client's first record: a 5-byte header whose bytes 3 and 4
	// hold the big-endian body length, followed by the body.
	var header [5]byte
	if _, err := io.ReadFull(s, header[:]); err != nil {
		t.Fatal(err)
	}
	recordLen := int(header[3])<<8 | int(header[4])

	record := make([]byte, recordLen)
	if _, err := io.ReadFull(s, record); err != nil {
		t.Fatal(err)
	}

	// Create a ServerHello that selects a different cipher suite than the
	// sole one that the client offered.
	serverHello := &serverHelloMsg{
		vers:        VersionTLS12,
		random:      make([]byte, 32),
		cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
	}
	serverHelloBytes := serverHello.marshal()

	// Frame it as a TLS 1.2 handshake record. Write errors are deliberately
	// ignored; the outcome is judged by the client's handshake error below.
	s.Write([]byte{
		byte(recordTypeHandshake),
		byte(VersionTLS12 >> 8),
		byte(VersionTLS12 & 0xff),
		byte(len(serverHelloBytes) >> 8),
		byte(len(serverHelloBytes)),
	})
	s.Write(serverHelloBytes)
	s.Close()

	// A nil error must be reported as a failure, not dereferenced.
	err := <-errChan
	if err == nil {
		t.Fatal("Expected error about unconfigured cipher suite but handshake succeeded")
	}
	if !strings.Contains(err.Error(), "unconfigured cipher") {
		t.Fatalf("Expected error about unconfigured cipher suite but got %q", err)
	}
}
1089
// TestVerifyPeerCertificate checks that Config.VerifyPeerCertificate callbacks
// run on client and server with the expected arguments (including the
// InsecureSkipVerify case, where no verified chains are passed), and that an
// error returned by a callback is returned unchanged from that side's
// Handshake.
func TestVerifyPeerCertificate(t *testing.T) {
	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
	if err != nil {
		t.Fatal(err)
	}

	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(issuer)

	// Fixed clock used for certificate verification on both sides.
	now := func() time.Time { return time.Unix(1476984729, 0) }

	sentinelErr := errors.New("TestVerifyPeerCertificate")

	// verifyCallback requires exactly one raw certificate and at least one
	// verified chain, and sets *called only when both hold.
	verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
		if l := len(rawCerts); l != 1 {
			return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
		}
		if len(validatedChains) == 0 {
			return errors.New("got len(validatedChains) = 0, wanted non-zero")
		}
		*called = true
		return nil
	}

	tests := []struct {
		configureServer func(*Config, *bool)
		configureClient func(*Config, *bool)
		validate        func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
	}{
		// Both sides verify normally and run verifyCallback.
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return verifyCallback(called, rawCerts, validatedChains)
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return verifyCallback(called, rawCerts, validatedChains)
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != nil {
					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
				}
				if serverErr != nil {
					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
				}
				if !clientCalled {
					t.Errorf("test[%d]: client did not call callback", testNo)
				}
				if !serverCalled {
					t.Errorf("test[%d]: server did not call callback", testNo)
				}
			},
		},
		// A server-side callback error aborts the server handshake.
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return sentinelErr
				}
			},
			configureClient: func(config *Config, called *bool) {
				config.VerifyPeerCertificate = nil
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if serverErr != sentinelErr {
					t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
				}
			},
		},
		// A client-side callback error aborts the client handshake.
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
			},
			configureClient: func(config *Config, called *bool) {
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					return sentinelErr
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != sentinelErr {
					t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
				}
			},
		},
		// With InsecureSkipVerify the client callback still runs, but
		// without any verified chains.
		{
			configureServer: func(config *Config, called *bool) {
				config.InsecureSkipVerify = false
			},
			configureClient: func(config *Config, called *bool) {
				config.InsecureSkipVerify = true
				config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
					if l := len(rawCerts); l != 1 {
						return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
					}
					// With InsecureSkipVerify set, this
					// callback should still be called but
					// validatedChains must be empty.
					if l := len(validatedChains); l != 0 {
						return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l)
					}
					*called = true
					return nil
				}
			},
			validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
				if clientErr != nil {
					t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr)
				}
				if serverErr != nil {
					t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr)
				}
				if !clientCalled {
					t.Errorf("test[%d]: client did not call callback", testNo)
				}
			},
		},
	}

	for i, test := range tests {
		c, s := net.Pipe()
		done := make(chan error)

		var clientCalled, serverCalled bool

		go func() {
			config := testConfig.Clone()
			config.ServerName = "example.golang"
			config.ClientAuth = RequireAndVerifyClientCert
			config.ClientCAs = rootCAs
			config.Time = now
			test.configureServer(config, &serverCalled)

			// Goroutine-local err: never write the enclosing function's
			// err from another goroutine.
			err := Server(s, config).Handshake()
			s.Close()
			done <- err
		}()

		config := testConfig.Clone()
		config.ServerName = "example.golang"
		config.RootCAs = rootCAs
		config.Time = now
		test.configureClient(config, &clientCalled)
		clientErr := Client(c, config).Handshake()
		c.Close()
		serverErr := <-done

		test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
	}
}
1243
1244// brokenConn wraps a net.Conn and causes all Writes after a certain number to
1245// fail with brokenConnErr.
1246type brokenConn struct {
1247	net.Conn
1248
1249	// breakAfter is the number of successful writes that will be allowed
1250	// before all subsequent writes fail.
1251	breakAfter int
1252
1253	// numWrites is the number of writes that have been done.
1254	numWrites int
1255}
1256
1257// brokenConnErr is the error that brokenConn returns once exhausted.
1258var brokenConnErr = errors.New("too many writes to brokenConn")
1259
1260func (b *brokenConn) Write(data []byte) (int, error) {
1261	if b.numWrites >= b.breakAfter {
1262		return 0, brokenConnErr
1263	}
1264
1265	b.numWrites++
1266	return b.Conn.Write(data)
1267}
1268
// TestFailedWrite checks that a write error hit during the client handshake
// is returned unchanged from Handshake, both when the very first write fails
// (breakAfter 0) and when the second one does (breakAfter 1).
func TestFailedWrite(t *testing.T) {
	for breakAfter := 0; breakAfter <= 1; breakAfter++ {
		clientSide, serverSide := net.Pipe()
		serverDone := make(chan bool)

		go func() {
			Server(serverSide, testConfig).Handshake()
			serverSide.Close()
			serverDone <- true
		}()

		conn := &brokenConn{Conn: clientSide, breakAfter: breakAfter}
		if err := Client(conn, testConfig).Handshake(); err != brokenConnErr {
			t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err)
		}
		conn.Close()

		<-serverDone
	}
}
1291
1292// writeCountingConn wraps a net.Conn and counts the number of Write calls.
1293type writeCountingConn struct {
1294	net.Conn
1295
1296	// numWrites is the number of writes that have been done.
1297	numWrites int
1298}
1299
1300func (wcc *writeCountingConn) Write(data []byte) (int, error) {
1301	wcc.numWrites++
1302	return wcc.Conn.Write(data)
1303}
1304
// TestBuffering checks that a complete handshake costs exactly two Write
// calls on each side of the connection.
func TestBuffering(t *testing.T) {
	clientPipe, serverPipe := net.Pipe()
	clientWCC := &writeCountingConn{Conn: clientPipe}
	serverWCC := &writeCountingConn{Conn: serverPipe}
	done := make(chan bool)

	go func() {
		Server(serverWCC, testConfig).Handshake()
		serverWCC.Close()
		done <- true
	}()

	if err := Client(clientWCC, testConfig).Handshake(); err != nil {
		t.Fatal(err)
	}
	clientWCC.Close()
	<-done

	if got := clientWCC.numWrites; got != 2 {
		t.Errorf("expected client handshake to complete with only two writes, but saw %d", got)
	}
	if got := serverWCC.numWrites; got != 2 {
		t.Errorf("expected server handshake to complete with only two writes, but saw %d", got)
	}
}
1333
// TestAlertFlushing makes the server fail while signing and checks two
// things: the client receives a "handshake failure" alert, and each side
// issues exactly one Write.
func TestAlertFlushing(t *testing.T) {
	clientPipe, serverPipe := net.Pipe()
	clientWCC := &writeCountingConn{Conn: clientPipe}
	serverWCC := &writeCountingConn{Conn: serverPipe}
	done := make(chan bool)

	// Keep the real public key but substitute a nonsense private exponent,
	// so the server hits an error when it tries to sign.
	badKey := rsa.PrivateKey{PublicKey: testRSAPrivateKey.PublicKey}
	badKey.D = big.NewInt(42)

	serverConfig := testConfig.Clone()
	serverConfig.Certificates = []Certificate{{
		Certificate: [][]byte{testRSACertificate},
		PrivateKey:  &badKey,
	}}

	go func() {
		Server(serverWCC, serverConfig).Handshake()
		serverWCC.Close()
		done <- true
	}()

	err := Client(clientWCC, testConfig).Handshake()
	if err == nil {
		t.Fatal("client unexpectedly returned no error")
	}

	const expectedError = "remote error: tls: handshake failure"
	if msg := err.Error(); !strings.Contains(msg, expectedError) {
		t.Fatalf("expected to find %q in error but error was %q", expectedError, msg)
	}
	clientWCC.Close()
	<-done

	if got := clientWCC.numWrites; got != 1 {
		t.Errorf("expected client handshake to complete with one write, but saw %d", got)
	}
	if got := serverWCC.numWrites; got != 1 {
		t.Errorf("expected server handshake to complete with one write, but saw %d", got)
	}
}
1377
// TestHandshakeRace makes a client Read and a client Write race to drive the
// handshake, alternating which of the two is released first. It is meant as
// evidence that the handshake locking has no races or deadlocks.
func TestHandshakeRace(t *testing.T) {
	t.Parallel()
	for i := 0; i < 32; i++ {
		clientPipe, serverPipe := net.Pipe()

		// Server: complete the handshake, echo one byte back, then close.
		go func() {
			server := Server(serverPipe, testConfig)
			if err := server.Handshake(); err != nil {
				panic(err)
			}

			var buf [1]byte
			if n, err := server.Read(buf[:]); err != nil || n != 1 {
				panic(err)
			}

			server.Write(buf[:])
			server.Close()
		}()

		startWrite := make(chan struct{})
		startRead := make(chan struct{})
		readDone := make(chan struct{})

		client := Client(clientPipe, testConfig)
		go func() {
			<-startWrite
			var buf [1]byte
			client.Write(buf[:])
		}()
		go func() {
			<-startRead
			var buf [1]byte
			if n, err := client.Read(buf[:]); err != nil || n != 1 {
				panic(err)
			}
			clientPipe.Close()
			readDone <- struct{}{}
		}()

		// Odd iterations release the writer first, even ones the reader.
		first, second := startRead, startWrite
		if i%2 == 1 {
			first, second = startWrite, startRead
		}
		first <- struct{}{}
		second <- struct{}{}
		<-readDone
	}
}
1432
// TestTLS11SignatureSchemes checks that tls11SignatureSchemes has exactly
// tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA entries.
func TestTLS11SignatureSchemes(t *testing.T) {
	want := tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA
	if got := len(tls11SignatureSchemes); got != want {
		t.Errorf("expected to find %d TLS 1.1 signature schemes, but found %d", want, got)
	}
}
1439
1440var getClientCertificateTests = []struct {
1441	setup               func(*Config)
1442	expectedClientError string
1443	verify              func(*testing.T, int, *ConnectionState)
1444}{
1445	{
1446		func(clientConfig *Config) {
1447			// Returning a Certificate with no certificate data
1448			// should result in an empty message being sent to the
1449			// server.
1450			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
1451				if len(cri.SignatureSchemes) == 0 {
1452					panic("empty SignatureSchemes")
1453				}
1454				return new(Certificate), nil
1455			}
1456		},
1457		"",
1458		func(t *testing.T, testNum int, cs *ConnectionState) {
1459			if l := len(cs.PeerCertificates); l != 0 {
1460				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
1461			}
1462		},
1463	},
1464	{
1465		func(clientConfig *Config) {
1466			// With TLS 1.1, the SignatureSchemes should be
1467			// synthesised from the supported certificate types.
1468			clientConfig.MaxVersion = VersionTLS11
1469			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
1470				if len(cri.SignatureSchemes) == 0 {
1471					panic("empty SignatureSchemes")
1472				}
1473				return new(Certificate), nil
1474			}
1475		},
1476		"",
1477		func(t *testing.T, testNum int, cs *ConnectionState) {
1478			if l := len(cs.PeerCertificates); l != 0 {
1479				t.Errorf("#%d: expected no certificates but got %d", testNum, l)
1480			}
1481		},
1482	},
1483	{
1484		func(clientConfig *Config) {
1485			// Returning an error should abort the handshake with
1486			// that error.
1487			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
1488				return nil, errors.New("GetClientCertificate")
1489			}
1490		},
1491		"GetClientCertificate",
1492		func(t *testing.T, testNum int, cs *ConnectionState) {
1493		},
1494	},
1495	{
1496		func(clientConfig *Config) {
1497			clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
1498				return &testConfig.Certificates[0], nil
1499			}
1500		},
1501		"",
1502		func(t *testing.T, testNum int, cs *ConnectionState) {
1503			if l := len(cs.VerifiedChains); l != 0 {
1504				t.Errorf("#%d: expected some verified chains, but found none", testNum)
1505			}
1506		},
1507	},
1508}
1509
1510func TestGetClientCertificate(t *testing.T) {
1511	issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
1512	if err != nil {
1513		panic(err)
1514	}
1515
1516	for i, test := range getClientCertificateTests {
1517		serverConfig := testConfig.Clone()
1518		serverConfig.ClientAuth = RequestClientCert
1519		serverConfig.RootCAs = x509.NewCertPool()
1520		serverConfig.RootCAs.AddCert(issuer)
1521
1522		clientConfig := testConfig.Clone()
1523
1524		test.setup(clientConfig)
1525
1526		type serverResult struct {
1527			cs  ConnectionState
1528			err error
1529		}
1530
1531		c, s := net.Pipe()
1532		done := make(chan serverResult)
1533
1534		go func() {
1535			defer s.Close()
1536			server := Server(s, serverConfig)
1537			err := server.Handshake()
1538
1539			var cs ConnectionState
1540			if err == nil {
1541				cs = server.ConnectionState()
1542			}
1543			done <- serverResult{cs, err}
1544		}()
1545
1546		clientErr := Client(c, clientConfig).Handshake()
1547		c.Close()
1548
1549		result := <-done
1550
1551		if clientErr != nil {
1552			if len(test.expectedClientError) == 0 {
1553				t.Errorf("#%d: client error: %v", i, clientErr)
1554			} else if got := clientErr.Error(); got != test.expectedClientError {
1555				t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
1556			}
1557		} else if len(test.expectedClientError) > 0 {
1558			t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
1559		} else if err := result.err; err != nil {
1560			t.Errorf("#%d: server error: %v", i, err)
1561		} else {
1562			test.verify(t, i, &result.cs)
1563		}
1564	}
1565}
1566