1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bytes"
9	"io"
10	"net"
11	"testing"
12)
13
// TestRoundUp checks roundUp(x, 16) for values on and around multiples of 16.
func TestRoundUp(t *testing.T) {
	cases := [...][2]int{
		{0, 0},
		{1, 16},
		{15, 16},
		{16, 16},
		{17, 32},
	}
	for _, c := range cases {
		if roundUp(c[0], 16) != c[1] {
			// Report a single failure and stop checking, like a
			// short-circuited chain of comparisons.
			t.Error("roundUp broken")
			break
		}
	}
}
23
// padding255Bad and padding255Good are 256-byte extractPadding inputs whose
// final byte, 255, claims that all 256 bytes are padding. Both are fully built
// at declaration, so paddingTests holds correct data no matter which code
// reads it first.

// padding255Bad is {0, 255, 255, ..., 255}: the mismatched first byte makes
// the padding invalid.
var padding255Bad = func() (b [256]byte) {
	for i := 1; i < len(b); i++ {
		b[i] = 255
	}
	return b
}()

// padding255Good is {255, 255, ..., 255}: every byte matches the length byte,
// so the whole input is valid padding.
var padding255Good = func() (b [256]byte) {
	for i := range b {
		b[i] = 255
	}
	return b
}()
29
30var paddingTests = []struct {
31	in          []byte
32	good        bool
33	expectedLen int
34}{
35	{[]byte{1, 2, 3, 4, 0}, true, 4},
36	{[]byte{1, 2, 3, 4, 0, 1}, false, 0},
37	{[]byte{1, 2, 3, 4, 99, 99}, false, 0},
38	{[]byte{1, 2, 3, 4, 1, 1}, true, 4},
39	{[]byte{1, 2, 3, 2, 2, 2}, true, 3},
40	{[]byte{1, 2, 3, 3, 3, 3}, true, 2},
41	{[]byte{1, 2, 3, 4, 3, 3}, false, 0},
42	{[]byte{1, 4, 4, 4, 4, 4}, true, 1},
43	{[]byte{5, 5, 5, 5, 5, 5}, true, 0},
44	{[]byte{6, 6, 6, 6, 6, 6}, false, 0},
45	{padding255Bad[:], false, 0},
46	{padding255Good[:], true, 0},
47}
48
49func TestRemovePadding(t *testing.T) {
50	for i := 1; i < len(padding255Bad); i++ {
51		padding255Bad[i] = 255
52		padding255Good[i] = 255
53	}
54	for i, test := range paddingTests {
55		paddingLen, good := extractPadding(test.in)
56		expectedGood := byte(255)
57		if !test.good {
58			expectedGood = 0
59		}
60		if good != expectedGood {
61			t.Errorf("#%d: wrong validity, want:%d got:%d", i, expectedGood, good)
62		}
63		if good == 255 && len(test.in)-paddingLen != test.expectedLen {
64			t.Errorf("#%d: got %d, want %d", i, len(test.in)-paddingLen, test.expectedLen)
65		}
66	}
67}
68
// Hex-encoded DER X.509 certificates for organization "Acme Co". Each
// certificate's subjectAltName extension lists the single DNS name noted on
// its variable.
var (
	// certExampleCom names example.com.
	certExampleCom = `308201713082011ba003020102021005a75ddf21014d5f417083b7a010ba2e300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343135335a170d3137303831373231343135335a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b37f0fdd67e715bf532046ac34acbd8fdc4dabe2b598588f3f58b1f12e6219a16cbfe54d2b4b665396013589262360b6721efa27d546854f17cc9aeec6751db10203010001a34d304b300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030160603551d11040f300d820b6578616d706c652e636f6d300d06092a864886f70d01010b050003410059fc487866d3d855503c8e064ca32aac5e9babcece89ec597f8b2b24c17867f4a5d3b4ece06e795bfc5448ccbd2ffca1b3433171ebf3557a4737b020565350a0`

	// certWildcardExampleCom names *.example.com.
	certWildcardExampleCom = `308201743082011ea003020102021100a7aa6297c9416a4633af8bec2958c607300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343231395a170d3137303831373231343231395a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100b105afc859a711ee864114e7d2d46c2dcbe392d3506249f6c2285b0eb342cc4bf2d803677c61c0abde443f084745c1a6d62080e5664ef2cc8f50ad8a0ab8870b0203010001a34f304d300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff0402300030180603551d110411300f820d2a2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100af26088584d266e3f6566360cf862c7fecc441484b098b107439543144a2b93f20781988281e108c6d7656934e56950e1e5f2bcf38796b814ccb729445856c34`

	// certFooExampleCom names foo.example.com.
	certFooExampleCom = `308201753082011fa00302010202101bbdb6070b0aeffc49008cde74deef29300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343234345a170d3137303831373231343234345a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100f00ac69d8ca2829f26216c7b50f1d4bbabad58d447706476cd89a2f3e1859943748aa42c15eedc93ac7c49e40d3b05ed645cb6b81c4efba60d961f44211a54eb0203010001a351304f300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000301a0603551d1104133011820f666f6f2e6578616d706c652e636f6d300d06092a864886f70d01010b0500034100a0957fca6d1e0f1ef4b247348c7a8ca092c29c9c0ecc1898ea6b8065d23af6d922a410dd2335a0ea15edd1394cef9f62c9e876a21e35250a0b4fe1ddceba0f36`

	// certDoubleWildcardExampleCom names *.*.example.com.
	certDoubleWildcardExampleCom = `308201753082011fa003020102021039d262d8538db8ffba30d204e02ddeb5300d06092a864886f70d01010b050030123110300e060355040a130741636d6520436f301e170d3136303831373231343331335a170d3137303831373231343331335a30123110300e060355040a130741636d6520436f305c300d06092a864886f70d0101010500034b003048024100abb6bd84b8b9be3fb9415d00f22b4ddcaec7c99855b9d818c09003e084578430e5cfd2e35faa3561f036d496aa43a9ca6e6cf23c72a763c04ae324004f6cbdbb0203010001a351304f300e0603551d0f0101ff0404030205a030130603551d25040c300a06082b06010505070301300c0603551d130101ff04023000301a0603551d1104133011820f2a2e2a2e6578616d706c652e636f6d300d06092a864886f70d01010b05000341004837521004a5b6bc7ad5d6c0dae60bb7ee0fa5e4825be35e2bb6ef07ee29396ca30ceb289431bcfd363888ba2207139933ac7c6369fa8810c819b2e2966abb4b`
)
76
// TestCertificateSelection checks which entry of Config.Certificates
// getCertificate picks, after BuildNameToCertificate, for exact names,
// wildcard names, and a name that matches none of the certificates (which is
// expected to fall back to certificate 0).
func TestCertificateSelection(t *testing.T) {
	// Certificates 0-3 name example.com, *.example.com, foo.example.com and
	// *.*.example.com respectively.
	config := Config{
		Certificates: []Certificate{
			{Certificate: [][]byte{fromHex(certExampleCom)}},
			{Certificate: [][]byte{fromHex(certWildcardExampleCom)}},
			{Certificate: [][]byte{fromHex(certFooExampleCom)}},
			{Certificate: [][]byte{fromHex(certDoubleWildcardExampleCom)}},
		},
	}
	config.BuildNameToCertificate()

	// indexOf maps a returned *Certificate back to its position in
	// config.Certificates, or -1 if it is not one of them (including nil).
	indexOf := func(c *Certificate) int {
		for i := range config.Certificates {
			if &config.Certificates[i] == c {
				return i
			}
		}
		return -1
	}

	// lookup asks getCertificate for the certificate to present for an SNI
	// server name, recording a test error and returning nil on failure.
	lookup := func(name string) *Certificate {
		cert, err := config.getCertificate(&ClientHelloInfo{ServerName: name})
		if err != nil {
			t.Errorf("unable to get certificate for name '%s': %s", name, err)
			return nil
		}
		return cert
	}

	expectations := []struct {
		name string
		want int
	}{
		{"example.com", 0},
		{"bar.example.com", 1},
		{"foo.example.com", 2},
		{"foo.bar.example.com", 3},
		{"foo.bar.baz.example.com", 0},
	}
	for _, e := range expectations {
		if got := indexOf(lookup(e.name)); got != e.want {
			t.Errorf("%s returned certificate %d, not %d", e.name, got, e.want)
		}
	}
}
134
// runDynamicRecordSizingTest checks the server's dynamic record sizing. With
// DynamicRecordSizingDisabled forced to false on the server, its records should
// start out small, with record #i at most (i+1)*tcpMSSEstimate bytes on the
// wire. Once a record of at least maxPlaintext bytes appears, every later
// record should be larger than maxPlaintext.
//
// Run with multiple crypto configs to test the logic for computing TLS record
// overheads. config is used as-is by the client; the server uses a clone.
func runDynamicRecordSizingTest(t *testing.T, config *Config) {
	clientConn, serverConn := localPipe(t)

	serverConfig := config.Clone()
	serverConfig.DynamicRecordSizingDisabled = false
	tlsConn := Server(serverConn, serverConfig)

	handshakeDone := make(chan struct{})
	recordSizesChan := make(chan []int, 1)
	defer func() { <-recordSizesChan }() // wait for the goroutine to exit
	go func() {
		// This goroutine performs a TLS handshake over clientConn and
		// then reads TLS records until EOF. It writes a slice that
		// contains all the record sizes to recordSizesChan. On any error
		// it returns without sending, so recordSizesChan is closed empty.
		// If the handshake itself fails, handshakeDone is never closed.
		defer close(recordSizesChan)
		defer clientConn.Close()

		tlsConn := Client(clientConn, config)
		if err := tlsConn.Handshake(); err != nil {
			t.Errorf("Error from client handshake: %v", err)
			return
		}
		close(handshakeDone)

		var recordHeader [recordHeaderLen]byte
		var record []byte
		var recordSizes []int

		// Records are read straight off clientConn, bypassing the TLS
		// client, so the sizes are exactly what the server wrote.
		for {
			n, err := io.ReadFull(clientConn, recordHeader[:])
			if err == io.EOF {
				break
			}
			if err != nil || n != len(recordHeader) {
				t.Errorf("io.ReadFull = %d, %v", n, err)
				return
			}

			// Header bytes 3 and 4 hold the big-endian payload length.
			length := int(recordHeader[3])<<8 | int(recordHeader[4])
			if len(record) < length {
				record = make([]byte, length)
			}

			n, err = io.ReadFull(clientConn, record[:length])
			if err != nil || n != length {
				t.Errorf("io.ReadFull = %d, %v", n, err)
				return
			}

			recordSizes = append(recordSizes, recordHeaderLen+length)
		}

		recordSizesChan <- recordSizes
	}()

	if err := tlsConn.Handshake(); err != nil {
		t.Fatalf("Error from server handshake: %s", err)
	}
	// The server's handshake may complete even if the client's then fails.
	// In that case handshakeDone is never closed. Also watch recordSizesChan,
	// which the client goroutine closes on every exit path, so the test fails
	// instead of blocking forever. No value can be sent on recordSizesChan
	// before handshakeDone is closed, so the first case always wins on success.
	select {
	case <-handshakeDone:
	case <-recordSizesChan:
		t.Fatalf("Client encountered an error")
	}

	// The server writes these plaintexts in order.
	plaintext := bytes.Join([][]byte{
		bytes.Repeat([]byte("x"), recordSizeBoostThreshold),
		bytes.Repeat([]byte("y"), maxPlaintext*2),
		bytes.Repeat([]byte("z"), maxPlaintext),
	}, nil)

	if _, err := tlsConn.Write(plaintext); err != nil {
		t.Fatalf("Error from server write: %s", err)
	}
	if err := tlsConn.Close(); err != nil {
		t.Fatalf("Error from server close: %s", err)
	}

	recordSizes := <-recordSizesChan
	if recordSizes == nil {
		t.Fatalf("Client encountered an error")
	}

	// Drop the size of the second to last record, which is likely to be
	// truncated, and the last record, which is a close_notify alert. Guard
	// the slice so an unexpectedly short stream fails cleanly, not by panic.
	if len(recordSizes) < 2 {
		t.Fatalf("Expected at least 2 records, got %d", len(recordSizes))
	}
	recordSizes = recordSizes[:len(recordSizes)-2]

	// recordSizes should contain a series of records smaller than
	// tcpMSSEstimate followed by some larger than maxPlaintext.
	seenLargeRecord := false
	for i, size := range recordSizes {
		if !seenLargeRecord {
			if size > (i+1)*tcpMSSEstimate {
				t.Fatalf("Record #%d has size %d, which is too large too soon", i, size)
			}
			if size >= maxPlaintext {
				seenLargeRecord = true
			}
		} else if size <= maxPlaintext {
			t.Fatalf("Record #%d has size %d but should be full sized", i, size)
		}
	}

	if !seenLargeRecord {
		t.Fatalf("No large records observed")
	}
}
239
// TestDynamicRecordSizingWithStreamCipher runs the dynamic record sizing check
// over TLS 1.2 with the RC4 stream cipher suite TLS_RSA_WITH_RC4_128_SHA.
func TestDynamicRecordSizingWithStreamCipher(t *testing.T) {
	cfg := testConfig.Clone()
	cfg.CipherSuites = []uint16{TLS_RSA_WITH_RC4_128_SHA}
	cfg.MaxVersion = VersionTLS12
	runDynamicRecordSizingTest(t, cfg)
}
246
// TestDynamicRecordSizingWithCBC runs the dynamic record sizing check over
// TLS 1.2 with the CBC cipher suite TLS_RSA_WITH_AES_256_CBC_SHA.
func TestDynamicRecordSizingWithCBC(t *testing.T) {
	cfg := testConfig.Clone()
	cfg.CipherSuites = []uint16{TLS_RSA_WITH_AES_256_CBC_SHA}
	cfg.MaxVersion = VersionTLS12
	runDynamicRecordSizingTest(t, cfg)
}
253
// TestDynamicRecordSizingWithAEAD runs the dynamic record sizing check over
// TLS 1.2 with the AEAD cipher suite TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256.
func TestDynamicRecordSizingWithAEAD(t *testing.T) {
	cfg := testConfig.Clone()
	cfg.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
	cfg.MaxVersion = VersionTLS12
	runDynamicRecordSizingTest(t, cfg)
}
260
// TestDynamicRecordSizingWithTLSv13 runs the dynamic record sizing check with
// an unmodified clone of testConfig. Going by the test name this is meant to
// negotiate TLS 1.3. NOTE(review): that depends on testConfig's version
// bounds, which are not visible here.
func TestDynamicRecordSizingWithTLSv13(t *testing.T) {
	runDynamicRecordSizingTest(t, testConfig.Clone())
}
265
266// hairpinConn is a net.Conn that makes a “hairpin” call when closed, back into
267// the tls.Conn which is calling it.
268type hairpinConn struct {
269	net.Conn
270	tlsConn *Conn
271}
272
273func (conn *hairpinConn) Close() error {
274	conn.tlsConn.ConnectionState()
275	return nil
276}
277
// TestHairpinInClose tests that the underlying net.Conn can call back into the
// tls.Conn while it is being closed, without deadlocking.
func TestHairpinInClose(t *testing.T) {
	client, server := localPipe(t)
	defer server.Close()
	defer client.Close()

	hairpin := &hairpinConn{Conn: client}

	// Close is called without any handshake, so this callback is not
	// expected to run.
	serverCfg := &Config{
		GetCertificate: func(*ClientHelloInfo) (*Certificate, error) {
			panic("unreachable")
		},
	}
	hairpin.tlsConn = Server(hairpin, serverCfg)

	// This call should not deadlock; its error is irrelevant to the test.
	_ = hairpin.tlsConn.Close()
}
296