1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build !js
6
7package net
8
9import (
10	"bytes"
11	"fmt"
12	"internal/poll"
13	"io"
14	"io/ioutil"
15	"reflect"
16	"runtime"
17	"sync"
18	"testing"
19)
20
21func TestBuffers_read(t *testing.T) {
22	const story = "once upon a time in Gopherland ... "
23	buffers := Buffers{
24		[]byte("once "),
25		[]byte("upon "),
26		[]byte("a "),
27		[]byte("time "),
28		[]byte("in "),
29		[]byte("Gopherland ... "),
30	}
31	got, err := ioutil.ReadAll(&buffers)
32	if err != nil {
33		t.Fatal(err)
34	}
35	if string(got) != story {
36		t.Errorf("read %q; want %q", got, story)
37	}
38	if len(buffers) != 0 {
39		t.Errorf("len(buffers) = %d; want 0", len(buffers))
40	}
41}
42
43func TestBuffers_consume(t *testing.T) {
44	tests := []struct {
45		in      Buffers
46		consume int64
47		want    Buffers
48	}{
49		{
50			in:      Buffers{[]byte("foo"), []byte("bar")},
51			consume: 0,
52			want:    Buffers{[]byte("foo"), []byte("bar")},
53		},
54		{
55			in:      Buffers{[]byte("foo"), []byte("bar")},
56			consume: 2,
57			want:    Buffers{[]byte("o"), []byte("bar")},
58		},
59		{
60			in:      Buffers{[]byte("foo"), []byte("bar")},
61			consume: 3,
62			want:    Buffers{[]byte("bar")},
63		},
64		{
65			in:      Buffers{[]byte("foo"), []byte("bar")},
66			consume: 4,
67			want:    Buffers{[]byte("ar")},
68		},
69		{
70			in:      Buffers{nil, nil, nil, []byte("bar")},
71			consume: 1,
72			want:    Buffers{[]byte("ar")},
73		},
74		{
75			in:      Buffers{nil, nil, nil, []byte("foo")},
76			consume: 0,
77			want:    Buffers{[]byte("foo")},
78		},
79		{
80			in:      Buffers{nil, nil, nil},
81			consume: 0,
82			want:    Buffers{},
83		},
84	}
85	for i, tt := range tests {
86		in := tt.in
87		in.consume(tt.consume)
88		if !reflect.DeepEqual(in, tt.want) {
89			t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
90		}
91	}
92}
93
94func TestBuffers_WriteTo(t *testing.T) {
95	for _, name := range []string{"WriteTo", "Copy"} {
96		for _, size := range []int{0, 10, 1023, 1024, 1025} {
97			t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
98				testBuffer_writeTo(t, size, name == "Copy")
99			})
100		}
101	}
102}
103
104func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
105	oldHook := poll.TestHookDidWritev
106	defer func() { poll.TestHookDidWritev = oldHook }()
107	var writeLog struct {
108		sync.Mutex
109		log []int
110	}
111	poll.TestHookDidWritev = func(size int) {
112		writeLog.Lock()
113		writeLog.log = append(writeLog.log, size)
114		writeLog.Unlock()
115	}
116	var want bytes.Buffer
117	for i := 0; i < chunks; i++ {
118		want.WriteByte(byte(i))
119	}
120
121	withTCPConnPair(t, func(c *TCPConn) error {
122		buffers := make(Buffers, chunks)
123		for i := range buffers {
124			buffers[i] = want.Bytes()[i : i+1]
125		}
126		var n int64
127		var err error
128		if useCopy {
129			n, err = io.Copy(c, &buffers)
130		} else {
131			n, err = buffers.WriteTo(c)
132		}
133		if err != nil {
134			return err
135		}
136		if len(buffers) != 0 {
137			return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
138		}
139		if n != int64(want.Len()) {
140			return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
141		}
142		return nil
143	}, func(c *TCPConn) error {
144		all, err := ioutil.ReadAll(c)
145		if !bytes.Equal(all, want.Bytes()) || err != nil {
146			return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
147		}
148
149		writeLog.Lock() // no need to unlock
150		var gotSum int
151		for _, v := range writeLog.log {
152			gotSum += v
153		}
154
155		var wantSum int
156		switch runtime.GOOS {
157		case "android", "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd":
158			var wantMinCalls int
159			wantSum = want.Len()
160			v := chunks
161			for v > 0 {
162				wantMinCalls++
163				v -= 1024
164			}
165			if len(writeLog.log) < wantMinCalls {
166				t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
167			}
168		case "windows":
169			var wantCalls int
170			wantSum = want.Len()
171			if wantSum > 0 {
172				wantCalls = 1 // windows will always do 1 syscall, unless sending empty buffer
173			}
174			if len(writeLog.log) != wantCalls {
175				t.Errorf("write calls = %v; want %v", len(writeLog.log), wantCalls)
176			}
177		}
178		if gotSum != wantSum {
179			t.Errorf("writev call sum  = %v; want %v", gotSum, wantSum)
180		}
181		return nil
182	})
183}
184
185func TestWritevError(t *testing.T) {
186	if runtime.GOOS == "windows" {
187		t.Skipf("skipping the test: windows does not have problem sending large chunks of data")
188	}
189
190	ln, err := newLocalListener("tcp")
191	if err != nil {
192		t.Fatal(err)
193	}
194	defer ln.Close()
195
196	ch := make(chan Conn, 1)
197	go func() {
198		defer close(ch)
199		c, err := ln.Accept()
200		if err != nil {
201			t.Error(err)
202			return
203		}
204		ch <- c
205	}()
206	c1, err := Dial("tcp", ln.Addr().String())
207	if err != nil {
208		t.Fatal(err)
209	}
210	defer c1.Close()
211	c2 := <-ch
212	if c2 == nil {
213		t.Fatal("no server side connection")
214	}
215	c2.Close()
216
217	// 1 GB of data should be enough to notice the connection is gone.
218	// Just a few bytes is not enough.
219	// Arrange to reuse the same 1 MB buffer so that we don't allocate much.
220	buf := make([]byte, 1<<20)
221	buffers := make(Buffers, 1<<10)
222	for i := range buffers {
223		buffers[i] = buf
224	}
225	if _, err := buffers.WriteTo(c1); err == nil {
226		t.Fatal("Buffers.WriteTo(closed conn) succeeded, want error")
227	}
228}
229