1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build !js
6
7package net
8
9import (
10	"bytes"
11	"crypto/sha256"
12	"encoding/hex"
13	"fmt"
14	"io"
15	"io/ioutil"
16	"os"
17	"runtime"
18	"sync"
19	"testing"
20	"time"
21)
22
// newton is the path of the text file that the sendfile tests serve.
const newton = "../testdata/Isaac.Newton-Opticks.txt"

// newtonLen is the expected size of the newton file, in bytes.
const newtonLen = 567198

// newtonSHA256 is the expected hex-encoded SHA-256 digest of the
// newton file's contents.
const newtonSHA256 = "d4a9ac22462b35e7821a4f2706c211093da678620a8f9997989ee7cf8d507bbd"
28
29func TestSendfile(t *testing.T) {
30	ln, err := newLocalListener("tcp")
31	if err != nil {
32		t.Fatal(err)
33	}
34	defer ln.Close()
35
36	errc := make(chan error, 1)
37	go func(ln Listener) {
38		// Wait for a connection.
39		conn, err := ln.Accept()
40		if err != nil {
41			errc <- err
42			close(errc)
43			return
44		}
45
46		go func() {
47			defer close(errc)
48			defer conn.Close()
49
50			f, err := os.Open(newton)
51			if err != nil {
52				errc <- err
53				return
54			}
55			defer f.Close()
56
57			// Return file data using io.Copy, which should use
58			// sendFile if available.
59			sbytes, err := io.Copy(conn, f)
60			if err != nil {
61				errc <- err
62				return
63			}
64
65			if sbytes != newtonLen {
66				errc <- fmt.Errorf("sent %d bytes; expected %d", sbytes, newtonLen)
67				return
68			}
69		}()
70	}(ln)
71
72	// Connect to listener to retrieve file and verify digest matches
73	// expected.
74	c, err := Dial("tcp", ln.Addr().String())
75	if err != nil {
76		t.Fatal(err)
77	}
78	defer c.Close()
79
80	h := sha256.New()
81	rbytes, err := io.Copy(h, c)
82	if err != nil {
83		t.Error(err)
84	}
85
86	if rbytes != newtonLen {
87		t.Errorf("received %d bytes; expected %d", rbytes, newtonLen)
88	}
89
90	if res := hex.EncodeToString(h.Sum(nil)); res != newtonSHA256 {
91		t.Error("retrieved data hash did not match")
92	}
93
94	for err := range errc {
95		t.Error(err)
96	}
97}
98
99func TestSendfileParts(t *testing.T) {
100	ln, err := newLocalListener("tcp")
101	if err != nil {
102		t.Fatal(err)
103	}
104	defer ln.Close()
105
106	errc := make(chan error, 1)
107	go func(ln Listener) {
108		// Wait for a connection.
109		conn, err := ln.Accept()
110		if err != nil {
111			errc <- err
112			close(errc)
113			return
114		}
115
116		go func() {
117			defer close(errc)
118			defer conn.Close()
119
120			f, err := os.Open(newton)
121			if err != nil {
122				errc <- err
123				return
124			}
125			defer f.Close()
126
127			for i := 0; i < 3; i++ {
128				// Return file data using io.CopyN, which should use
129				// sendFile if available.
130				_, err = io.CopyN(conn, f, 3)
131				if err != nil {
132					errc <- err
133					return
134				}
135			}
136		}()
137	}(ln)
138
139	c, err := Dial("tcp", ln.Addr().String())
140	if err != nil {
141		t.Fatal(err)
142	}
143	defer c.Close()
144
145	buf := new(bytes.Buffer)
146	buf.ReadFrom(c)
147
148	if want, have := "Produced ", buf.String(); have != want {
149		t.Errorf("unexpected server reply %q, want %q", have, want)
150	}
151
152	for err := range errc {
153		t.Error(err)
154	}
155}
156
157func TestSendfileSeeked(t *testing.T) {
158	ln, err := newLocalListener("tcp")
159	if err != nil {
160		t.Fatal(err)
161	}
162	defer ln.Close()
163
164	const seekTo = 65 << 10
165	const sendSize = 10 << 10
166
167	errc := make(chan error, 1)
168	go func(ln Listener) {
169		// Wait for a connection.
170		conn, err := ln.Accept()
171		if err != nil {
172			errc <- err
173			close(errc)
174			return
175		}
176
177		go func() {
178			defer close(errc)
179			defer conn.Close()
180
181			f, err := os.Open(newton)
182			if err != nil {
183				errc <- err
184				return
185			}
186			defer f.Close()
187			if _, err := f.Seek(seekTo, os.SEEK_SET); err != nil {
188				errc <- err
189				return
190			}
191
192			_, err = io.CopyN(conn, f, sendSize)
193			if err != nil {
194				errc <- err
195				return
196			}
197		}()
198	}(ln)
199
200	c, err := Dial("tcp", ln.Addr().String())
201	if err != nil {
202		t.Fatal(err)
203	}
204	defer c.Close()
205
206	buf := new(bytes.Buffer)
207	buf.ReadFrom(c)
208
209	if buf.Len() != sendSize {
210		t.Errorf("Got %d bytes; want %d", buf.Len(), sendSize)
211	}
212
213	for err := range errc {
214		t.Error(err)
215	}
216}
217
218// Test that sendfile doesn't put a pipe into blocking mode.
219func TestSendfilePipe(t *testing.T) {
220	switch runtime.GOOS {
221	case "plan9", "windows":
222		// These systems don't support deadlines on pipes.
223		t.Skipf("skipping on %s", runtime.GOOS)
224	}
225
226	t.Parallel()
227
228	ln, err := newLocalListener("tcp")
229	if err != nil {
230		t.Fatal(err)
231	}
232	defer ln.Close()
233
234	r, w, err := os.Pipe()
235	if err != nil {
236		t.Fatal(err)
237	}
238	defer w.Close()
239	defer r.Close()
240
241	copied := make(chan bool)
242
243	var wg sync.WaitGroup
244	wg.Add(1)
245	go func() {
246		// Accept a connection and copy 1 byte from the read end of
247		// the pipe to the connection. This will call into sendfile.
248		defer wg.Done()
249		conn, err := ln.Accept()
250		if err != nil {
251			t.Error(err)
252			return
253		}
254		defer conn.Close()
255		_, err = io.CopyN(conn, r, 1)
256		if err != nil {
257			t.Error(err)
258			return
259		}
260		// Signal the main goroutine that we've copied the byte.
261		close(copied)
262	}()
263
264	wg.Add(1)
265	go func() {
266		// Write 1 byte to the write end of the pipe.
267		defer wg.Done()
268		_, err := w.Write([]byte{'a'})
269		if err != nil {
270			t.Error(err)
271		}
272	}()
273
274	wg.Add(1)
275	go func() {
276		// Connect to the server started two goroutines up and
277		// discard any data that it writes.
278		defer wg.Done()
279		conn, err := Dial("tcp", ln.Addr().String())
280		if err != nil {
281			t.Error(err)
282			return
283		}
284		defer conn.Close()
285		io.Copy(ioutil.Discard, conn)
286	}()
287
288	// Wait for the byte to be copied, meaning that sendfile has
289	// been called on the pipe.
290	<-copied
291
292	// Set a very short deadline on the read end of the pipe.
293	if err := r.SetDeadline(time.Now().Add(time.Microsecond)); err != nil {
294		t.Fatal(err)
295	}
296
297	wg.Add(1)
298	go func() {
299		// Wait for much longer than the deadline and write a byte
300		// to the pipe.
301		defer wg.Done()
302		time.Sleep(50 * time.Millisecond)
303		w.Write([]byte{'b'})
304	}()
305
306	// If this read does not time out, the pipe was incorrectly
307	// put into blocking mode.
308	_, err = r.Read(make([]byte, 1))
309	if err == nil {
310		t.Error("Read did not time out")
311	} else if !os.IsTimeout(err) {
312		t.Errorf("got error %v, expected a time out", err)
313	}
314
315	wg.Wait()
316}
317