1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package nettest
6
7import (
8	"bytes"
9	"encoding/binary"
10	"io"
11	"io/ioutil"
12	"math/rand"
13	"net"
14	"runtime"
15	"sync"
16	"testing"
17	"time"
18)
19
// MakePipe is a constructor for a connected pair of endpoints, c1 and c2.
// Bytes written to c1 must become readable on c2, and bytes written to c2
// must become readable on c1. The returned stop function must be non-nil;
// calling it releases every resource tied to the pair, including c1, c2, and
// the underlying net.Listener when one exists.
type MakePipe func() (c1, c2 net.Conn, stop func(), err error)
25
26// TestConn tests that a net.Conn implementation properly satisfies the interface.
27// The tests should not produce any false positives, but may experience
28// false negatives. Thus, some issues may only be detected when the test is
29// run multiple times. For maximal effectiveness, run the tests under the
30// race detector.
31func TestConn(t *testing.T, mp MakePipe) {
32	t.Run("BasicIO", func(t *testing.T) { timeoutWrapper(t, mp, testBasicIO) })
33	t.Run("PingPong", func(t *testing.T) { timeoutWrapper(t, mp, testPingPong) })
34	t.Run("RacyRead", func(t *testing.T) { timeoutWrapper(t, mp, testRacyRead) })
35	t.Run("RacyWrite", func(t *testing.T) { timeoutWrapper(t, mp, testRacyWrite) })
36	t.Run("ReadTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testReadTimeout) })
37	t.Run("WriteTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testWriteTimeout) })
38	t.Run("PastTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testPastTimeout) })
39	t.Run("PresentTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testPresentTimeout) })
40	t.Run("FutureTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testFutureTimeout) })
41	t.Run("CloseTimeout", func(t *testing.T) { timeoutWrapper(t, mp, testCloseTimeout) })
42	t.Run("ConcurrentMethods", func(t *testing.T) { timeoutWrapper(t, mp, testConcurrentMethods) })
43}
44
45type connTester func(t *testing.T, c1, c2 net.Conn)
46
47func timeoutWrapper(t *testing.T, mp MakePipe, f connTester) {
48	t.Helper()
49	c1, c2, stop, err := mp()
50	if err != nil {
51		t.Fatalf("unable to make pipe: %v", err)
52	}
53	var once sync.Once
54	defer once.Do(func() { stop() })
55	timer := time.AfterFunc(time.Minute, func() {
56		once.Do(func() {
57			t.Error("test timed out; terminating pipe")
58			stop()
59		})
60	})
61	defer timer.Stop()
62	f(t, c1, c2)
63}
64
65// testBasicIO tests that the data sent on c1 is properly received on c2.
66func testBasicIO(t *testing.T, c1, c2 net.Conn) {
67	want := make([]byte, 1<<20)
68	rand.New(rand.NewSource(0)).Read(want)
69
70	dataCh := make(chan []byte)
71	go func() {
72		rd := bytes.NewReader(want)
73		if err := chunkedCopy(c1, rd); err != nil {
74			t.Errorf("unexpected c1.Write error: %v", err)
75		}
76		if err := c1.Close(); err != nil {
77			t.Errorf("unexpected c1.Close error: %v", err)
78		}
79	}()
80
81	go func() {
82		wr := new(bytes.Buffer)
83		if err := chunkedCopy(wr, c2); err != nil {
84			t.Errorf("unexpected c2.Read error: %v", err)
85		}
86		if err := c2.Close(); err != nil {
87			t.Errorf("unexpected c2.Close error: %v", err)
88		}
89		dataCh <- wr.Bytes()
90	}()
91
92	if got := <-dataCh; !bytes.Equal(got, want) {
93		t.Error("transmitted data differs")
94	}
95}
96
// testPingPong tests that the two endpoints can synchronously send data to
// each other in a typical request-response pattern.
//
// Both sides run the same loop. Each reads an 8-byte little-endian counter,
// increments it and writes it back. The side that receives 1000 stops and
// closes its conn. The other side is then expected to see io.EOF and close
// its own. The exchange is started by writing a zero counter to c1.
func testPingPong(t *testing.T, c1, c2 net.Conn) {
	var wg sync.WaitGroup
	defer wg.Wait()

	// pingPonger plays one side of the exchange on c. It checks that every
	// value it receives advanced by 2 since the previous one it saw (the check
	// is skipped while prev is still 0), then closes c when the loop ends.
	pingPonger := func(c net.Conn) {
		defer wg.Done()
		buf := make([]byte, 8)
		var prev uint64
		for {
			if _, err := io.ReadFull(c, buf); err != nil {
				// io.EOF is the normal end of the exchange; anything else is
				// a failure. Stop in both cases: retrying a failed Read would
				// otherwise spin forever, re-reporting the error and
				// decoding a stale buffer.
				if err != io.EOF {
					t.Errorf("unexpected Read error: %v", err)
				}
				break
			}

			v := binary.LittleEndian.Uint64(buf)
			binary.LittleEndian.PutUint64(buf, v+1)
			if prev != 0 && prev+2 != v {
				t.Errorf("mismatching value: got %d, want %d", v, prev+2)
			}
			prev = v
			if v == 1000 {
				break
			}

			if _, err := c.Write(buf); err != nil {
				t.Errorf("unexpected Write error: %v", err)
				break
			}
		}
		if err := c.Close(); err != nil {
			t.Errorf("unexpected Close error: %v", err)
		}
	}

	wg.Add(2)
	go pingPonger(c1)
	go pingPonger(c2)

	// Start off the chain reaction.
	if _, err := c1.Write(make([]byte, 8)); err != nil {
		t.Errorf("unexpected c1.Write error: %v", err)
	}
}
144
145// testRacyRead tests that it is safe to mutate the input Read buffer
146// immediately after cancelation has occurred.
147func testRacyRead(t *testing.T, c1, c2 net.Conn) {
148	go chunkedCopy(c2, rand.New(rand.NewSource(0)))
149
150	var wg sync.WaitGroup
151	defer wg.Wait()
152
153	c1.SetReadDeadline(time.Now().Add(time.Millisecond))
154	for i := 0; i < 10; i++ {
155		wg.Add(1)
156		go func() {
157			defer wg.Done()
158
159			b1 := make([]byte, 1024)
160			b2 := make([]byte, 1024)
161			for j := 0; j < 100; j++ {
162				_, err := c1.Read(b1)
163				copy(b1, b2) // Mutate b1 to trigger potential race
164				if err != nil {
165					checkForTimeoutError(t, err)
166					c1.SetReadDeadline(time.Now().Add(time.Millisecond))
167				}
168			}
169		}()
170	}
171}
172
173// testRacyWrite tests that it is safe to mutate the input Write buffer
174// immediately after cancelation has occurred.
175func testRacyWrite(t *testing.T, c1, c2 net.Conn) {
176	go chunkedCopy(ioutil.Discard, c2)
177
178	var wg sync.WaitGroup
179	defer wg.Wait()
180
181	c1.SetWriteDeadline(time.Now().Add(time.Millisecond))
182	for i := 0; i < 10; i++ {
183		wg.Add(1)
184		go func() {
185			defer wg.Done()
186
187			b1 := make([]byte, 1024)
188			b2 := make([]byte, 1024)
189			for j := 0; j < 100; j++ {
190				_, err := c1.Write(b1)
191				copy(b1, b2) // Mutate b1 to trigger potential race
192				if err != nil {
193					checkForTimeoutError(t, err)
194					c1.SetWriteDeadline(time.Now().Add(time.Millisecond))
195				}
196			}
197		}()
198	}
199}
200
201// testReadTimeout tests that Read timeouts do not affect Write.
202func testReadTimeout(t *testing.T, c1, c2 net.Conn) {
203	go chunkedCopy(ioutil.Discard, c2)
204
205	c1.SetReadDeadline(aLongTimeAgo)
206	_, err := c1.Read(make([]byte, 1024))
207	checkForTimeoutError(t, err)
208	if _, err := c1.Write(make([]byte, 1024)); err != nil {
209		t.Errorf("unexpected Write error: %v", err)
210	}
211}
212
213// testWriteTimeout tests that Write timeouts do not affect Read.
214func testWriteTimeout(t *testing.T, c1, c2 net.Conn) {
215	go chunkedCopy(c2, rand.New(rand.NewSource(0)))
216
217	c1.SetWriteDeadline(aLongTimeAgo)
218	_, err := c1.Write(make([]byte, 1024))
219	checkForTimeoutError(t, err)
220	if _, err := c1.Read(make([]byte, 1024)); err != nil {
221		t.Errorf("unexpected Read error: %v", err)
222	}
223}
224
225// testPastTimeout tests that a deadline set in the past immediately times out
226// Read and Write requests.
227func testPastTimeout(t *testing.T, c1, c2 net.Conn) {
228	go chunkedCopy(c2, c2)
229
230	testRoundtrip(t, c1)
231
232	c1.SetDeadline(aLongTimeAgo)
233	n, err := c1.Write(make([]byte, 1024))
234	if n != 0 {
235		t.Errorf("unexpected Write count: got %d, want 0", n)
236	}
237	checkForTimeoutError(t, err)
238	n, err = c1.Read(make([]byte, 1024))
239	if n != 0 {
240		t.Errorf("unexpected Read count: got %d, want 0", n)
241	}
242	checkForTimeoutError(t, err)
243
244	testRoundtrip(t, c1)
245}
246
// testPresentTimeout tests that a past deadline set while there are pending
// Read and Write operations immediately times out those operations.
//
// Three goroutines run against c1 while c2 is left idle (this function never
// reads from or writes to it):
//   - after sleeping 100ms, one goroutine signals on deadlineSet and then moves
//     both the read and write deadlines to aLongTimeAgo;
//   - one goroutine blocks in a single Read;
//   - one goroutine calls Write in a loop until it fails (presumably once any
//     buffering in the implementation fills up — depends on the conn).
//
// Both I/O goroutines must fail with a timeout error, and neither may fail
// before the deadline goroutine has signaled.
func testPresentTimeout(t *testing.T, c1, c2 net.Conn) {
	var wg sync.WaitGroup
	defer wg.Wait()
	wg.Add(3)

	// Buffered so the single signaling send never blocks. The I/O goroutines
	// inspect it with len() instead of receiving from it.
	deadlineSet := make(chan bool, 1)
	go func() {
		defer wg.Done()
		time.Sleep(100 * time.Millisecond)
		// Signal first, then set the deadlines, so that a timeout observed
		// while the channel is still empty indicates a premature timeout.
		deadlineSet <- true
		c1.SetReadDeadline(aLongTimeAgo)
		c1.SetWriteDeadline(aLongTimeAgo)
	}()
	go func() {
		defer wg.Done()
		n, err := c1.Read(make([]byte, 1024))
		if n != 0 {
			t.Errorf("unexpected Read count: got %d, want 0", n)
		}
		checkForTimeoutError(t, err)
		if len(deadlineSet) == 0 {
			t.Error("Read timed out before deadline is set")
		}
	}()
	go func() {
		defer wg.Done()
		var err error
		// Keep writing until the pending Write is interrupted.
		for err == nil {
			_, err = c1.Write(make([]byte, 1024))
		}
		checkForTimeoutError(t, err)
		if len(deadlineSet) == 0 {
			t.Error("Write timed out before deadline is set")
		}
	}()
}
285
286// testFutureTimeout tests that a future deadline will eventually time out
287// Read and Write operations.
288func testFutureTimeout(t *testing.T, c1, c2 net.Conn) {
289	var wg sync.WaitGroup
290	wg.Add(2)
291
292	c1.SetDeadline(time.Now().Add(100 * time.Millisecond))
293	go func() {
294		defer wg.Done()
295		_, err := c1.Read(make([]byte, 1024))
296		checkForTimeoutError(t, err)
297	}()
298	go func() {
299		defer wg.Done()
300		var err error
301		for err == nil {
302			_, err = c1.Write(make([]byte, 1024))
303		}
304		checkForTimeoutError(t, err)
305	}()
306	wg.Wait()
307
308	go chunkedCopy(c2, c2)
309	resyncConn(t, c1)
310	testRoundtrip(t, c1)
311}
312
313// testCloseTimeout tests that calling Close immediately times out pending
314// Read and Write operations.
315func testCloseTimeout(t *testing.T, c1, c2 net.Conn) {
316	go chunkedCopy(c2, c2)
317
318	var wg sync.WaitGroup
319	defer wg.Wait()
320	wg.Add(3)
321
322	// Test for cancelation upon connection closure.
323	c1.SetDeadline(neverTimeout)
324	go func() {
325		defer wg.Done()
326		time.Sleep(100 * time.Millisecond)
327		c1.Close()
328	}()
329	go func() {
330		defer wg.Done()
331		var err error
332		buf := make([]byte, 1024)
333		for err == nil {
334			_, err = c1.Read(buf)
335		}
336	}()
337	go func() {
338		defer wg.Done()
339		var err error
340		buf := make([]byte, 1024)
341		for err == nil {
342			_, err = c1.Write(buf)
343		}
344	}()
345}
346
347// testConcurrentMethods tests that the methods of net.Conn can safely
348// be called concurrently.
349func testConcurrentMethods(t *testing.T, c1, c2 net.Conn) {
350	if runtime.GOOS == "plan9" {
351		t.Skip("skipping on plan9; see https://golang.org/issue/20489")
352	}
353	go chunkedCopy(c2, c2)
354
355	// The results of the calls may be nonsensical, but this should
356	// not trigger a race detector warning.
357	var wg sync.WaitGroup
358	for i := 0; i < 100; i++ {
359		wg.Add(7)
360		go func() {
361			defer wg.Done()
362			c1.Read(make([]byte, 1024))
363		}()
364		go func() {
365			defer wg.Done()
366			c1.Write(make([]byte, 1024))
367		}()
368		go func() {
369			defer wg.Done()
370			c1.SetDeadline(time.Now().Add(10 * time.Millisecond))
371		}()
372		go func() {
373			defer wg.Done()
374			c1.SetReadDeadline(aLongTimeAgo)
375		}()
376		go func() {
377			defer wg.Done()
378			c1.SetWriteDeadline(aLongTimeAgo)
379		}()
380		go func() {
381			defer wg.Done()
382			c1.LocalAddr()
383		}()
384		go func() {
385			defer wg.Done()
386			c1.RemoteAddr()
387		}()
388	}
389	wg.Wait() // At worst, the deadline is set 10ms into the future
390
391	resyncConn(t, c1)
392	testRoundtrip(t, c1)
393}
394
// checkForTimeoutError reports a test error unless err implements net.Error
// and its Timeout method returns true. A nil err is reported as having the
// wrong type, since nil does not satisfy the type assertion.
func checkForTimeoutError(t *testing.T, err error) {
	t.Helper()
	nerr, isNetErr := err.(net.Error)
	switch {
	case !isNetErr:
		t.Errorf("got %T, want net.Error", err)
	case !nerr.Timeout():
		t.Errorf("err.Timeout() = false, want true")
	}
}
407
408// testRoundtrip writes something into c and reads it back.
409// It assumes that everything written into c is echoed back to itself.
410func testRoundtrip(t *testing.T, c net.Conn) {
411	t.Helper()
412	if err := c.SetDeadline(neverTimeout); err != nil {
413		t.Errorf("roundtrip SetDeadline error: %v", err)
414	}
415
416	const s = "Hello, world!"
417	buf := []byte(s)
418	if _, err := c.Write(buf); err != nil {
419		t.Errorf("roundtrip Write error: %v", err)
420	}
421	if _, err := io.ReadFull(c, buf); err != nil {
422		t.Errorf("roundtrip Read error: %v", err)
423	}
424	if string(buf) != s {
425		t.Errorf("roundtrip data mismatch: got %q, want %q", buf, s)
426	}
427}
428
429// resyncConn resynchronizes the connection into a sane state.
430// It assumes that everything written into c is echoed back to itself.
431// It assumes that 0xff is not currently on the wire or in the read buffer.
432func resyncConn(t *testing.T, c net.Conn) {
433	t.Helper()
434	c.SetDeadline(neverTimeout)
435	errCh := make(chan error)
436	go func() {
437		_, err := c.Write([]byte{0xff})
438		errCh <- err
439	}()
440	buf := make([]byte, 1024)
441	for {
442		n, err := c.Read(buf)
443		if n > 0 && bytes.IndexByte(buf[:n], 0xff) == n-1 {
444			break
445		}
446		if err != nil {
447			t.Errorf("unexpected Read error: %v", err)
448			break
449		}
450	}
451	if err := <-errCh; err != nil {
452		t.Errorf("unexpected Write error: %v", err)
453	}
454}
455
// chunkedCopy copies from r to w in fixed-width chunks of at most 1024 bytes.
// This keeps each Write under the maximum packet size of packet-based
// connections like "unixpacket" (assumed to be at least 1024). It returns the
// first error other than io.EOF, or nil when r is exhausted.
func chunkedCopy(w io.Writer, r io.Reader) error {
	// Wrapping both ends hides any ReaderFrom/WriterTo methods. Otherwise
	// io.CopyBuffer would use those fast paths and ignore the chunk buffer.
	type plainWriter struct{ io.Writer }
	type plainReader struct{ io.Reader }

	chunk := make([]byte, 1024)
	if _, err := io.CopyBuffer(plainWriter{w}, plainReader{r}, chunk); err != nil {
		return err
	}
	return nil
}
465