1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bufio"
9	"encoding/hex"
10	"errors"
11	"flag"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"net"
16	"os/exec"
17	"strconv"
18	"strings"
19	"sync"
20	"testing"
21)
22
23// TLS reference tests run a connection against a reference implementation
24// (OpenSSL) of TLS and record the bytes of the resulting connection. The Go
25// code, during a test, is configured with deterministic randomness and so the
26// reference test can be reproduced exactly in the future.
27//
// So that running the tests does not require the reference implementation to
// be installed, the reference connections are saved in files in the testdata
// directory. Thus running the tests involves nothing external, but creating
// and updating them requires the reference implementation.
33//
// Tests can be updated by running them with the -update flag. This will cause
// the test files to be regenerated. Generally one should combine the -update
// flag with -test.run to update a specific test. Since the reference
// implementation will always generate fresh random numbers, large parts of
// the reference connection will always change.
39
// update, set with -update, makes the tests regenerate their golden files in
// testdata instead of only replaying them.
var update = flag.Bool("update", false, "update golden files on disk")

// The OpenSSL version check is performed at most once per test binary;
// opensslVersionTestOnce guards it and opensslVersionTestErr holds its
// outcome (nil means the check passed or was skipped).
var (
	opensslVersionTestOnce sync.Once
	opensslVersionTestErr  error
)
46
47func checkOpenSSLVersion(t *testing.T) {
48	opensslVersionTestOnce.Do(testOpenSSLVersion)
49	if opensslVersionTestErr != nil {
50		t.Fatal(opensslVersionTestErr)
51	}
52}
53
54func testOpenSSLVersion() {
55	// This test ensures that the version of OpenSSL looks reasonable
56	// before updating the test data.
57
58	if !*update {
59		return
60	}
61
62	openssl := exec.Command("openssl", "version")
63	output, err := openssl.CombinedOutput()
64	if err != nil {
65		opensslVersionTestErr = err
66		return
67	}
68
69	version := string(output)
70	if strings.HasPrefix(version, "OpenSSL 1.1.0") {
71		return
72	}
73
74	println("***********************************************")
75	println("")
76	println("You need to build OpenSSL 1.1.0 from source in order")
77	println("to update the test data.")
78	println("")
79	println("Configure it with:")
80	println("./Configure enable-weak-ssl-ciphers enable-ssl3 enable-ssl3-method -static linux-x86_64")
81	println("and then add the apps/ directory at the front of your PATH.")
82	println("***********************************************")
83
84	opensslVersionTestErr = errors.New("version of OpenSSL does not appear to be suitable for updating test data")
85}
86
// recordingConn is a net.Conn that records the traffic that passes through it.
// Consecutive reads are coalesced into a single flow, as are consecutive
// writes, so successive flows alternate direction. WriteTo can be used to
// produce output that can later be loaded with parseTestData.
type recordingConn struct {
	net.Conn
	sync.Mutex // guards flows and reading

	flows   [][]byte // recorded flows in order; each is a private copy
	reading bool     // whether the most recent flow was recorded by Read
}

// Read reads from the underlying connection and records any bytes received.
func (r *recordingConn) Read(b []byte) (n int, err error) {
	n, err = r.Conn.Read(b)
	if n > 0 {
		r.record(b[:n], true)
	}
	return n, err
}

// Write writes to the underlying connection and records the bytes that were
// actually written.
func (r *recordingConn) Write(b []byte) (n int, err error) {
	n, err = r.Conn.Write(b)
	if n > 0 {
		r.record(b[:n], false)
	}
	return n, err
}

// record adds p to the recorded traffic. reading reports whether p came from
// Read. p is appended to the current flow when it travels in the same
// direction as the previous call; otherwise it starts a new flow. p is copied,
// so the caller may reuse its buffer.
func (r *recordingConn) record(p []byte, reading bool) {
	r.Lock()
	defer r.Unlock()

	if l := len(r.flows); l > 0 && r.reading == reading {
		r.flows[l-1] = append(r.flows[l-1], p...)
	} else {
		r.flows = append(r.flows, append([]byte(nil), p...))
	}
	r.reading = reading
}

// WriteTo writes a textual dump of the recorded traffic to w. For each flow it
// writes a ">>> Flow N (source to dest)" header line followed by a hex dump
// of the flow's bytes, which is the format read by parseTestData. It returns
// the number of bytes written to w.
func (r *recordingConn) WriteTo(w io.Writer) (int64, error) {
	r.Lock()
	defer r.Unlock()

	var written int64
	for i, flow := range r.flows {
		// TLS always starts with a client to server flow, and flows
		// alternate direction from then on.
		source, dest := "client", "server"
		if i%2 == 1 {
			source, dest = dest, source
		}
		n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest)
		written += int64(n)
		if err != nil {
			return written, err
		}
		// hex.Dump yields the same text as writing flow to a hex.Dumper
		// and closing it. Writing the result directly lets us count the
		// bytes that reach w: a Dumper's Write reports input bytes
		// consumed, and the output produced by Close is not reported.
		n, err = io.WriteString(w, hex.Dump(flow))
		written += int64(n)
		if err != nil {
			return written, err
		}
	}
	return written, nil
}
166
// parseTestData parses recorded traffic in the format produced by
// recordingConn.WriteTo. Each flow begins with a line starting with ">>> " and
// is followed by hex dump lines. It returns the decoded bytes of each flow, in
// order.
//
// Empty flows between non-empty ones are preserved, but empty flows at the
// start or end of the input are dropped. An error is returned for a malformed
// dump line or if reading r fails.
func parseTestData(r io.Reader) (flows [][]byte, err error) {
	var currentFlow []byte

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		// If the line starts with ">>> " then it marks the beginning
		// of a new flow.
		if strings.HasPrefix(line, ">>> ") {
			// Close the flow accumulated so far, unless nothing at all
			// has been seen yet (the usual case for the first header).
			if len(currentFlow) > 0 || len(flows) > 0 {
				flows = append(flows, currentFlow)
				currentFlow = nil
			}
			continue
		}

		// Otherwise the line is a line of hex dump that looks like:
		// 00000170  fc f5 06 bf (...)  |.....X{&?......!|
		// (Some bytes have been omitted from the middle section.)
		// Drop the offset column (up to the first space) and the ASCII
		// column (from the first '|'), leaving only the hex bytes.
		i := strings.IndexByte(line, ' ')
		if i < 0 {
			return nil, errors.New("invalid test data")
		}
		line = line[i:]

		i = strings.IndexByte(line, '|')
		if i < 0 {
			return nil, errors.New("invalid test data")
		}
		line = line[:i]

		for _, hexByte := range strings.Fields(line) {
			val, err := strconv.ParseUint(hexByte, 16, 8)
			if err != nil {
				return nil, errors.New("invalid hex byte in test data: " + err.Error())
			}
			currentFlow = append(currentFlow, byte(val))
		}
	}

	// Scan stops silently on a read error. Without this check, the flows
	// read so far would be returned as if they were complete.
	if err = scanner.Err(); err != nil {
		return nil, err
	}

	if len(currentFlow) > 0 {
		flows = append(flows, currentFlow)
	}

	return flows, nil
}
215
// tempFile creates a temp file containing contents and returns its path.
// It panics if the file cannot be created, written, or closed. The file is
// not removed here; that is left to the caller.
func tempFile(contents string) string {
	file, err := ioutil.TempFile("", "go-tls-test")
	if err != nil {
		panic("failed to create temp file: " + err.Error())
	}
	path := file.Name()
	_, err = file.WriteString(contents)
	// Close even if the write failed, and report the first error seen. A
	// failing Close can also mean the contents were not fully written.
	if closeErr := file.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		panic("failed to write temp file " + path + ": " + err.Error())
	}
	return path
}
227