xref: /openbsd/regress/lib/libtls/gotls/tls_test.go (revision 274d7c50)
1package tls
2
3import (
4	"crypto/tls"
5	"encoding/pem"
6	"fmt"
7	"io/ioutil"
8	"net/http"
9	"net/http/httptest"
10	"net/url"
11	"os"
12	"strings"
13	"testing"
14	"time"
15)
16
const (
	// certHash is the PeerCertHash value expected for the test server's
	// certificate (checked in TestTLSInfo).
	certHash = "SHA256:448f628a8a65aa18560e53a80c53acb38c51b427df0334082349141147dc9bf6"

	// httpContent is the line written by the test server's handler and
	// searched for in the responses the client reads back.
	httpContent = "Hello, TLS!"
)
22
23var (
24	certNotBefore = time.Unix(0, 0)
25	certNotAfter  = certNotBefore.Add(1000000 * time.Hour)
26)
27
// handshakeError marks a failure of the TLS handshake step itself, as
// opposed to a failure while setting up the connection. handshakeVersionTest
// returns it so callers can tell the two apart with a type assertion.
type handshakeError string

// Error implements the error interface by returning the message text.
func (e handshakeError) Error() string {
	msg := string(e)
	return msg
}
33
// createCAFile writes cert (DER bytes) as a PEM "CERTIFICATE" block to a new
// temporary file, for use by libtls as a CA file, and returns the file's
// path. The caller is responsible for removing the file. On failure the
// temporary file is removed and no path is returned.
func createCAFile(cert []byte) (string, error) {
	f, err := ioutil.TempFile("", "tls")
	if err != nil {
		return "", fmt.Errorf("failed to create file: %v", err)
	}
	block := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: cert,
	}
	if err := pem.Encode(f, block); err != nil {
		// Best-effort cleanup; the encode error is the one worth reporting.
		f.Close()
		os.Remove(f.Name())
		return "", fmt.Errorf("failed to encode certificate: %v", err)
	}
	// Check the Close error rather than deferring it: close can report
	// errors from earlier writes, which would leave libtls reading a
	// truncated CA file.
	if err := f.Close(); err != nil {
		os.Remove(f.Name())
		return "", fmt.Errorf("failed to close file: %v", err)
	}
	return f.Name(), nil
}
51
// newTestServer starts an HTTPS test server whose handler writes httpContent
// to every response. tlsCfg, which may be nil, is installed as the server's
// TLS configuration before it is started.
//
// It returns the running server, its parsed URL, and the path of a temporary
// PEM file holding the server's leaf certificate, suitable for use as a
// libtls CA file. The caller must Close the server and remove the CA file.
// If an error is returned, the server has already been closed.
func newTestServer(tlsCfg *tls.Config) (*httptest.Server, *url.URL, string, error) {
	ts := httptest.NewUnstartedServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintln(w, httpContent)
		}),
	)
	ts.TLS = tlsCfg
	ts.StartTLS()

	u, err := url.Parse(ts.URL)
	if err != nil {
		// The server is already listening; don't leak it on failure.
		ts.Close()
		return nil, nil, "", fmt.Errorf("failed to parse URL %q: %v", ts.URL, err)
	}

	// Relies on StartTLS having populated ts.TLS.Certificates (httptest
	// supplies its built-in certificate when the config has none).
	caFile, err := createCAFile(ts.TLS.Certificates[0].Certificate[0])
	if err != nil {
		ts.Close()
		return nil, nil, "", fmt.Errorf("failed to create CA file: %v", err)
	}

	return ts, u, caFile, nil
}
75
// handshakeVersionTest starts an HTTPS test server configured with tlsCfg,
// connects to it with a libtls client whose protocols are set to
// ProtocolsAll and whose cipher list is "compat", and returns the protocol
// version reported by ConnVersion after the handshake.
//
// A failure of the Handshake call itself is returned as a handshakeError so
// callers can distinguish it from setup failures; every other error is
// returned annotated with the step that failed.
func handshakeVersionTest(tlsCfg *tls.Config) (ProtocolVersion, error) {
	ts, u, caFile, err := newTestServer(tlsCfg)
	if err != nil {
		return 0, fmt.Errorf("failed to start test server: %v", err)
	}
	defer os.Remove(caFile)
	defer ts.Close()

	if err := Init(); err != nil {
		return 0, fmt.Errorf("failed to initialize libtls: %v", err)
	}

	cfg, err := NewConfig()
	if err != nil {
		return 0, fmt.Errorf("failed to create config: %v", err)
	}
	defer cfg.Free()
	if err := cfg.SetCAFile(caFile); err != nil {
		return 0, fmt.Errorf("failed to set CA file: %v", err)
	}
	if err := cfg.SetCiphers("compat"); err != nil {
		return 0, fmt.Errorf("failed to set ciphers: %v", err)
	}
	if err := cfg.SetProtocols(ProtocolsAll); err != nil {
		return 0, fmt.Errorf("failed to set protocols: %v", err)
	}

	// Named client (not tls) so the crypto/tls package used in this
	// function's signature is not shadowed in its body.
	client, err := NewClient(cfg)
	if err != nil {
		return 0, fmt.Errorf("failed to create client: %v", err)
	}
	defer client.Free()

	if err := client.Connect(u.Host, ""); err != nil {
		return 0, fmt.Errorf("failed to connect to %s: %v", u.Host, err)
	}
	if err := client.Handshake(); err != nil {
		// Returned unwrapped so callers can type-assert handshakeError.
		return 0, handshakeError(err.Error())
	}
	version, err := client.ConnVersion()
	if err != nil {
		return 0, fmt.Errorf("failed to get connection version: %v", err)
	}
	if err := client.Close(); err != nil {
		return 0, fmt.Errorf("failed to close connection: %v", err)
	}
	return version, nil
}
124
125func TestTLSBasic(t *testing.T) {
126	ts, u, caFile, err := newTestServer(nil)
127	if err != nil {
128		t.Fatalf("Failed to start test server: %v", err)
129	}
130	defer os.Remove(caFile)
131	defer ts.Close()
132
133	if err := Init(); err != nil {
134		t.Fatal(err)
135	}
136
137	cfg, err := NewConfig()
138	if err != nil {
139		t.Fatal(err)
140	}
141	defer cfg.Free()
142	if err := cfg.SetCAFile(caFile); err != nil {
143		t.Fatal(err)
144	}
145
146	tls, err := NewClient(cfg)
147	if err != nil {
148		t.Fatal(err)
149	}
150	defer tls.Free()
151
152	t.Logf("Connecting to %s", u.Host)
153
154	if err := tls.Connect(u.Host, ""); err != nil {
155		t.Fatal(err)
156	}
157	defer func() {
158		if err := tls.Close(); err != nil {
159			t.Fatalf("Close failed: %v", err)
160		}
161	}()
162
163	n, err := tls.Write([]byte("GET / HTTP/1.0\n\n"))
164	if err != nil {
165		t.Fatal(err)
166	}
167	t.Logf("Wrote %d bytes...", n)
168
169	buf := make([]byte, 1024)
170	n, err = tls.Read(buf)
171	if err != nil {
172		t.Fatal(err)
173	}
174	t.Logf("Read %d bytes...", n)
175
176	if !strings.Contains(string(buf), httpContent) {
177		t.Errorf("Response does not contain %q", httpContent)
178	}
179}
180
181func TestTLSVersions(t *testing.T) {
182	tests := []struct {
183		minVersion       uint16
184		maxVersion       uint16
185		wantVersion      ProtocolVersion
186		wantHandshakeErr bool
187	}{
188		{tls.VersionSSL30, tls.VersionTLS12, ProtocolTLSv12, false},
189		{tls.VersionTLS10, tls.VersionTLS12, ProtocolTLSv12, false},
190		{tls.VersionTLS11, tls.VersionTLS12, ProtocolTLSv12, false},
191		{tls.VersionSSL30, tls.VersionTLS11, ProtocolTLSv11, false},
192		{tls.VersionSSL30, tls.VersionTLS10, ProtocolTLSv10, false},
193		{tls.VersionSSL30, tls.VersionSSL30, 0, true},
194		{tls.VersionTLS10, tls.VersionTLS10, ProtocolTLSv10, false},
195		{tls.VersionTLS11, tls.VersionTLS11, ProtocolTLSv11, false},
196		{tls.VersionTLS12, tls.VersionTLS12, ProtocolTLSv12, false},
197	}
198	for i, test := range tests {
199		t.Logf("Testing handshake with protocols %x:%x", test.minVersion, test.maxVersion)
200		tlsCfg := &tls.Config{
201			MinVersion: test.minVersion,
202			MaxVersion: test.maxVersion,
203		}
204		version, err := handshakeVersionTest(tlsCfg)
205		switch {
206		case test.wantHandshakeErr && err == nil:
207			t.Errorf("Test %d - handshake %x:%x succeeded, want handshake error",
208				i, test.minVersion, test.maxVersion)
209		case test.wantHandshakeErr && err != nil:
210			if _, ok := err.(handshakeError); !ok {
211				t.Errorf("Test %d - handshake %x:%x; got unknown error, want handshake error: %v",
212					i, test.minVersion, test.maxVersion, err)
213			}
214		case !test.wantHandshakeErr && err != nil:
215			t.Errorf("Test %d - handshake %x:%x failed: %v", i, test.minVersion, test.maxVersion, err)
216		case !test.wantHandshakeErr && err == nil:
217			if got, want := version, test.wantVersion; got != want {
218				t.Errorf("Test %d - handshake %x:%x; got protocol version %v, want %v",
219					i, test.minVersion, test.maxVersion, got, want)
220			}
221		}
222	}
223}
224
// TestTLSSingleByteReadWrite sends an HTTP/1.0 request one byte per Write,
// reads the whole response back one byte per Read until end of stream, and
// checks that it contains httpContent.
func TestTLSSingleByteReadWrite(t *testing.T) {
	ts, u, caFile, err := newTestServer(nil)
	if err != nil {
		t.Fatalf("Failed to start test server: %v", err)
	}
	defer os.Remove(caFile)
	defer ts.Close()

	if err := Init(); err != nil {
		t.Fatal(err)
	}

	cfg, err := NewConfig()
	if err != nil {
		t.Fatal(err)
	}
	defer cfg.Free()
	if err := cfg.SetCAFile(caFile); err != nil {
		t.Fatal(err)
	}

	conn, err := NewClient(cfg)
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Free()

	t.Logf("Connecting to %s", u.Host)
	if err := conn.Connect(u.Host, ""); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := conn.Close(); err != nil {
			t.Fatalf("Close failed: %v", err)
		}
	}()

	// Write the request a single byte at a time.
	request := []byte("GET / HTTP/1.0\n\n")
	for i := 0; i < len(request); i++ {
		c := request[i]
		written, err := conn.Write([]byte{c})
		if err != nil {
			t.Fatal(err)
		}
		if written != 1 {
			t.Fatalf("Wrote byte %v, got length %d, want 1", c, written)
		}
	}

	// Read the response a single byte at a time; (0, nil) marks the end.
	var body []byte
	one := make([]byte, 1)
readLoop:
	for {
		got, err := conn.Read(one)
		switch {
		case err != nil:
			t.Fatal(err)
		case got == 0:
			break readLoop
		case got != 1:
			t.Fatalf("Read single byte, got length %d, want 1", got)
		}
		body = append(body, one...)
	}

	if !strings.Contains(string(body), httpContent) {
		t.Errorf("Response does not contain %q", httpContent)
	}
}
293
// TestTLSInfo exercises the connection and peer-certificate accessors.
// Before the handshake, each accessor must return an error (or false).
// After it, they must succeed, and the certificate hash and validity window
// must match certHash, certNotBefore and certNotAfter.
func TestTLSInfo(t *testing.T) {
	ts, u, caFile, err := newTestServer(nil)
	if err != nil {
		t.Fatalf("Failed to start test server: %v", err)
	}
	defer os.Remove(caFile)
	defer ts.Close()

	if err := Init(); err != nil {
		t.Fatal(err)
	}

	cfg, err := NewConfig()
	if err != nil {
		t.Fatal(err)
	}
	defer cfg.Free()
	if err := cfg.SetCAFile(caFile); err != nil {
		t.Fatal(err)
	}

	// Named client so the crypto/tls package is not shadowed.
	client, err := NewClient(cfg)
	if err != nil {
		t.Fatal(err)
	}
	defer client.Free()

	t.Logf("Connecting to %s", u.Host)

	if err := client.Connect(u.Host, ""); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := client.Close(); err != nil {
			t.Fatalf("Close failed: %v", err)
		}
	}()

	// Names the server certificate is expected to contain once the
	// handshake has completed.
	certNames := []string{"127.0.0.1", "::1", "example.com"}

	// All of these should fail since the handshake has not completed.
	if _, err := client.ConnVersion(); err == nil {
		t.Error("ConnVersion() returned nil error, want error")
	}
	if _, err := client.ConnCipher(); err == nil {
		t.Error("ConnCipher() returned nil error, want error")
	}
	if _, err := client.ConnCipherStrength(); err == nil {
		t.Error("ConnCipherStrength() returned nil error, want error")
	}

	if got, want := client.PeerCertProvided(), false; got != want {
		t.Errorf("PeerCertProvided() = %v, want %v", got, want)
	}
	for _, name := range certNames {
		if got, want := client.PeerCertContainsName(name), false; got != want {
			t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want)
		}
	}

	if _, err := client.PeerCertIssuer(); err == nil {
		t.Error("PeerCertIssuer() returned nil error, want error")
	}
	if _, err := client.PeerCertSubject(); err == nil {
		t.Error("PeerCertSubject() returned nil error, want error")
	}
	if _, err := client.PeerCertHash(); err == nil {
		t.Error("PeerCertHash() returned nil error, want error")
	}
	if _, err := client.PeerCertNotBefore(); err == nil {
		t.Error("PeerCertNotBefore() returned nil error, want error")
	}
	if _, err := client.PeerCertNotAfter(); err == nil {
		t.Error("PeerCertNotAfter() returned nil error, want error")
	}

	// Complete the handshake...
	if err := client.Handshake(); err != nil {
		t.Fatalf("Handshake failed: %v", err)
	}

	if version, err := client.ConnVersion(); err != nil {
		t.Errorf("ConnVersion() returned error: %v", err)
	} else {
		t.Logf("Protocol version: %v", version)
	}
	if cipher, err := client.ConnCipher(); err != nil {
		t.Errorf("ConnCipher() returned error: %v", err)
	} else {
		t.Logf("Cipher: %v", cipher)
	}
	if strength, err := client.ConnCipherStrength(); err != nil {
		t.Errorf("ConnCipherStrength() returned error: %v", err)
	} else {
		t.Logf("Cipher Strength: %v bits", strength)
	}

	if got, want := client.PeerCertProvided(), true; got != want {
		t.Errorf("PeerCertProvided() = %v, want %v", got, want)
	}
	for _, name := range certNames {
		if got, want := client.PeerCertContainsName(name), true; got != want {
			t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want)
		}
	}

	if issuer, err := client.PeerCertIssuer(); err != nil {
		t.Errorf("PeerCertIssuer() returned error: %v", err)
	} else {
		t.Logf("Issuer: %v", issuer)
	}
	if subject, err := client.PeerCertSubject(); err != nil {
		t.Errorf("PeerCertSubject() returned error: %v", err)
	} else {
		t.Logf("Subject: %v", subject)
	}
	if hash, err := client.PeerCertHash(); err != nil {
		t.Errorf("PeerCertHash() returned error: %v", err)
	} else if hash != certHash {
		t.Errorf("Got cert hash %q, want %q", hash, certHash)
	} else {
		t.Logf("Hash: %v", hash)
	}
	if notBefore, err := client.PeerCertNotBefore(); err != nil {
		t.Errorf("PeerCertNotBefore() returned error: %v", err)
	} else if !certNotBefore.Equal(notBefore) {
		t.Errorf("Got cert notBefore %v, want %v", notBefore.UTC(), certNotBefore.UTC())
	} else {
		t.Logf("NotBefore: %v", notBefore.UTC())
	}
	if notAfter, err := client.PeerCertNotAfter(); err != nil {
		t.Errorf("PeerCertNotAfter() returned error: %v", err)
	} else if !certNotAfter.Equal(notAfter) {
		t.Errorf("Got cert notAfter %v, want %v", notAfter.UTC(), certNotAfter.UTC())
	} else {
		t.Logf("NotAfter: %v", notAfter.UTC())
	}
}
430