1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package tls
6
7import (
8	"bufio"
9	"encoding/hex"
10	"errors"
11	"flag"
12	"fmt"
13	"io"
14	"io/ioutil"
15	"net"
16	"os"
17	"os/exec"
18	"strconv"
19	"strings"
20	"sync"
21	"testing"
22)
23
24// TLS reference tests run a connection against a reference implementation
25// (OpenSSL) of TLS and record the bytes of the resulting connection. The Go
26// code, during a test, is configured with deterministic randomness and so the
27// reference test can be reproduced exactly in the future.
28//
// So that running the tests does not require the reference implementation to
// be installed, the recorded connections are saved as files in the testdata
// directory. Running the tests therefore involves nothing external, but
// creating and updating the recordings requires the reference implementation.
34//
// Tests can be updated by running them with the -update flag, which causes
// the test files to be regenerated. Generally one should combine -update with
// -test.run to update a specific test. Since the reference implementation
// always generates fresh random numbers, large parts of the recorded
// connection will change on every update.
40
// update, when set on the command line, makes the reference tests regenerate
// their golden files in testdata instead of only reading them.
var update = flag.Bool("update", false, "update golden files on disk")

// opensslVersionTestOnce ensures the OpenSSL version check runs only once per
// test binary; its outcome (nil when acceptable) is kept in
// opensslVersionTestErr.
var (
	opensslVersionTestOnce sync.Once
	opensslVersionTestErr  error
)
47
48func checkOpenSSLVersion(t *testing.T) {
49	opensslVersionTestOnce.Do(testOpenSSLVersion)
50	if opensslVersionTestErr != nil {
51		t.Fatal(opensslVersionTestErr)
52	}
53}
54
55func testOpenSSLVersion() {
56	// This test ensures that the version of OpenSSL looks reasonable
57	// before updating the test data.
58
59	if !*update {
60		return
61	}
62
63	openssl := exec.Command("openssl", "version")
64	output, err := openssl.CombinedOutput()
65	if err != nil {
66		opensslVersionTestErr = err
67		return
68	}
69
70	version := string(output)
71	if strings.HasPrefix(version, "OpenSSL 1.1.1") {
72		return
73	}
74
75	println("***********************************************")
76	println("")
77	println("You need to build OpenSSL 1.1.1 from source in order")
78	println("to update the test data.")
79	println("")
80	println("Configure it with:")
81	println("./Configure enable-weak-ssl-ciphers enable-ssl3 enable-ssl3-method")
82	println("and then add the apps/ directory at the front of your PATH.")
83	println("***********************************************")
84
85	opensslVersionTestErr = errors.New("version of OpenSSL does not appear to be suitable for updating test data")
86}
87
// recordingConn wraps a net.Conn and records the bytes that pass through it,
// grouped into flows: each flow holds a run of consecutive reads or of
// consecutive writes, and a change of direction starts a new flow. The
// recording can be serialized with WriteTo into the text format that
// parseTestData reads back.
type recordingConn struct {
	net.Conn   // connection that actually carries the traffic
	sync.Mutex // guards flows and reading

	flows   [][]byte // recorded traffic, one entry per direction change
	reading bool     // whether the most recently recorded transfer was a Read
}
97
98func (r *recordingConn) Read(b []byte) (n int, err error) {
99	if n, err = r.Conn.Read(b); n == 0 {
100		return
101	}
102	b = b[:n]
103
104	r.Lock()
105	defer r.Unlock()
106
107	if l := len(r.flows); l == 0 || !r.reading {
108		buf := make([]byte, len(b))
109		copy(buf, b)
110		r.flows = append(r.flows, buf)
111	} else {
112		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
113	}
114	r.reading = true
115	return
116}
117
118func (r *recordingConn) Write(b []byte) (n int, err error) {
119	if n, err = r.Conn.Write(b); n == 0 {
120		return
121	}
122	b = b[:n]
123
124	r.Lock()
125	defer r.Unlock()
126
127	if l := len(r.flows); l == 0 || r.reading {
128		buf := make([]byte, len(b))
129		copy(buf, b)
130		r.flows = append(r.flows, buf)
131	} else {
132		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
133	}
134	r.reading = false
135	return
136}
137
138// WriteTo writes Go source code to w that contains the recorded traffic.
139func (r *recordingConn) WriteTo(w io.Writer) (int64, error) {
140	// TLS always starts with a client to server flow.
141	clientToServer := true
142	var written int64
143	for i, flow := range r.flows {
144		source, dest := "client", "server"
145		if !clientToServer {
146			source, dest = dest, source
147		}
148		n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest)
149		written += int64(n)
150		if err != nil {
151			return written, err
152		}
153		dumper := hex.Dumper(w)
154		n, err = dumper.Write(flow)
155		written += int64(n)
156		if err != nil {
157			return written, err
158		}
159		err = dumper.Close()
160		if err != nil {
161			return written, err
162		}
163		clientToServer = !clientToServer
164	}
165	return written, nil
166}
167
// parseTestData parses the text produced by recordingConn.WriteTo back into
// the recorded flows. A line beginning with ">>> " starts a new flow; every
// other line must be a hex.Dumper line whose hex bytes are appended to the
// current flow. It returns an error for malformed lines and for any error
// encountered while reading r (including a line too long for bufio.Scanner).
func parseTestData(r io.Reader) (flows [][]byte, err error) {
	var currentFlow []byte

	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		// A ">>> " header closes the flow in progress. Nothing is appended
		// for a leading header with no data before it; every later header
		// appends the current flow, even if it is empty.
		if strings.HasPrefix(line, ">>> ") {
			if len(currentFlow) > 0 || len(flows) > 0 {
				flows = append(flows, currentFlow)
				currentFlow = nil
			}
			continue
		}

		// Otherwise the line is a line of hex dump that looks like:
		// 00000170  fc f5 06 bf (...)  |.....X{&?......!|
		// (Some bytes have been omitted from the middle section.)
		// Drop the offset column, then everything from the first '|',
		// which opens the ASCII column.

		if i := strings.IndexByte(line, ' '); i >= 0 {
			line = line[i:]
		} else {
			return nil, errors.New("invalid test data")
		}

		if i := strings.IndexByte(line, '|'); i >= 0 {
			line = line[:i]
		} else {
			return nil, errors.New("invalid test data")
		}

		for _, hexByte := range strings.Fields(line) {
			val, err := strconv.ParseUint(hexByte, 16, 8)
			if err != nil {
				return nil, errors.New("invalid hex byte in test data: " + err.Error())
			}
			currentFlow = append(currentFlow, byte(val))
		}
	}
	// Scan returns false both at EOF and on failure; without this check a
	// read error or an over-long line would silently truncate the data.
	if err := scanner.Err(); err != nil {
		return nil, err
	}

	if len(currentFlow) > 0 {
		flows = append(flows, currentFlow)
	}

	return flows, nil
}
216
// tempFile creates a temporary file containing contents and returns its path.
// tempFile does not remove the file; callers should delete it when done. It
// panics if the file cannot be created or fully written, since its signature
// offers no way to report a partially written file.
func tempFile(contents string) string {
	file, err := ioutil.TempFile("", "go-tls-test")
	if err != nil {
		panic("failed to create temp file: " + err.Error())
	}
	path := file.Name()

	_, err = file.WriteString(contents)
	// Close even after a failed write; a Close error can also mean the data
	// never reached the disk.
	if cerr := file.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		// Best-effort cleanup of the partial file; we are about to panic, so
		// a removal failure has nowhere more useful to go.
		os.Remove(path)
		panic("failed to write temp file: " + err.Error())
	}
	return path
}
228
// localListener is a loopback TCP listener shared by the test binary.
// TestMain assigns the listener before the tests run and closes it afterwards.
// localPipe holds the embedded mutex while it dials and accepts, producing
// pairs of connections that behave like net.Pipe but go through a real,
// buffered TCP socket.
var localListener struct {
	sync.Mutex
	net.Listener
}
235
236func localPipe(t testing.TB) (net.Conn, net.Conn) {
237	localListener.Lock()
238	defer localListener.Unlock()
239	c := make(chan net.Conn)
240	go func() {
241		conn, err := localListener.Accept()
242		if err != nil {
243			t.Errorf("Failed to accept local connection: %v", err)
244		}
245		c <- conn
246	}()
247	addr := localListener.Addr()
248	c1, err := net.Dial(addr.Network(), addr.String())
249	if err != nil {
250		t.Fatalf("Failed to dial local connection: %v", err)
251	}
252	c2 := <-c
253	return c1, c2
254}
255
256func TestMain(m *testing.M) {
257	l, err := net.Listen("tcp", "127.0.0.1:0")
258	if err != nil {
259		l, err = net.Listen("tcp6", "[::1]:0")
260	}
261	if err != nil {
262		fmt.Fprintf(os.Stderr, "Failed to open local listener: %v", err)
263		os.Exit(1)
264	}
265	localListener.Listener = l
266	exitCode := m.Run()
267	localListener.Close()
268	os.Exit(exitCode)
269}
270