1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package net
6
7import (
8	"bytes"
9	"fmt"
10	"internal/poll"
11	"io"
12	"io/ioutil"
13	"reflect"
14	"runtime"
15	"sync"
16	"testing"
17)
18
19func TestBuffers_read(t *testing.T) {
20	const story = "once upon a time in Gopherland ... "
21	buffers := Buffers{
22		[]byte("once "),
23		[]byte("upon "),
24		[]byte("a "),
25		[]byte("time "),
26		[]byte("in "),
27		[]byte("Gopherland ... "),
28	}
29	got, err := ioutil.ReadAll(&buffers)
30	if err != nil {
31		t.Fatal(err)
32	}
33	if string(got) != story {
34		t.Errorf("read %q; want %q", got, story)
35	}
36	if len(buffers) != 0 {
37		t.Errorf("len(buffers) = %d; want 0", len(buffers))
38	}
39}
40
41func TestBuffers_consume(t *testing.T) {
42	tests := []struct {
43		in      Buffers
44		consume int64
45		want    Buffers
46	}{
47		{
48			in:      Buffers{[]byte("foo"), []byte("bar")},
49			consume: 0,
50			want:    Buffers{[]byte("foo"), []byte("bar")},
51		},
52		{
53			in:      Buffers{[]byte("foo"), []byte("bar")},
54			consume: 2,
55			want:    Buffers{[]byte("o"), []byte("bar")},
56		},
57		{
58			in:      Buffers{[]byte("foo"), []byte("bar")},
59			consume: 3,
60			want:    Buffers{[]byte("bar")},
61		},
62		{
63			in:      Buffers{[]byte("foo"), []byte("bar")},
64			consume: 4,
65			want:    Buffers{[]byte("ar")},
66		},
67		{
68			in:      Buffers{nil, nil, nil, []byte("bar")},
69			consume: 1,
70			want:    Buffers{[]byte("ar")},
71		},
72		{
73			in:      Buffers{nil, nil, nil, []byte("foo")},
74			consume: 0,
75			want:    Buffers{[]byte("foo")},
76		},
77		{
78			in:      Buffers{nil, nil, nil},
79			consume: 0,
80			want:    Buffers{},
81		},
82	}
83	for i, tt := range tests {
84		in := tt.in
85		in.consume(tt.consume)
86		if !reflect.DeepEqual(in, tt.want) {
87			t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
88		}
89	}
90}
91
92func TestBuffers_WriteTo(t *testing.T) {
93	for _, name := range []string{"WriteTo", "Copy"} {
94		for _, size := range []int{0, 10, 1023, 1024, 1025} {
95			t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
96				testBuffer_writeTo(t, size, name == "Copy")
97			})
98		}
99	}
100}
101
102func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
103	oldHook := poll.TestHookDidWritev
104	defer func() { poll.TestHookDidWritev = oldHook }()
105	var writeLog struct {
106		sync.Mutex
107		log []int
108	}
109	poll.TestHookDidWritev = func(size int) {
110		writeLog.Lock()
111		writeLog.log = append(writeLog.log, size)
112		writeLog.Unlock()
113	}
114	var want bytes.Buffer
115	for i := 0; i < chunks; i++ {
116		want.WriteByte(byte(i))
117	}
118
119	withTCPConnPair(t, func(c *TCPConn) error {
120		buffers := make(Buffers, chunks)
121		for i := range buffers {
122			buffers[i] = want.Bytes()[i : i+1]
123		}
124		var n int64
125		var err error
126		if useCopy {
127			n, err = io.Copy(c, &buffers)
128		} else {
129			n, err = buffers.WriteTo(c)
130		}
131		if err != nil {
132			return err
133		}
134		if len(buffers) != 0 {
135			return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
136		}
137		if n != int64(want.Len()) {
138			return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
139		}
140		return nil
141	}, func(c *TCPConn) error {
142		all, err := ioutil.ReadAll(c)
143		if !bytes.Equal(all, want.Bytes()) || err != nil {
144			return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
145		}
146
147		writeLog.Lock() // no need to unlock
148		var gotSum int
149		for _, v := range writeLog.log {
150			gotSum += v
151		}
152
153		var wantSum int
154		switch runtime.GOOS {
155		case "android", "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd":
156			var wantMinCalls int
157			wantSum = want.Len()
158			v := chunks
159			for v > 0 {
160				wantMinCalls++
161				v -= 1024
162			}
163			if len(writeLog.log) < wantMinCalls {
164				t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
165			}
166		case "windows":
167			var wantCalls int
168			wantSum = want.Len()
169			if wantSum > 0 {
170				wantCalls = 1 // windows will always do 1 syscall, unless sending empty buffer
171			}
172			if len(writeLog.log) != wantCalls {
173				t.Errorf("write calls = %v; want %v", len(writeLog.log), wantCalls)
174			}
175		}
176		if gotSum != wantSum {
177			t.Errorf("writev call sum  = %v; want %v", gotSum, wantSum)
178		}
179		return nil
180	})
181}
182
183func TestWritevError(t *testing.T) {
184	if runtime.GOOS == "windows" {
185		t.Skipf("skipping the test: windows does not have problem sending large chunks of data")
186	}
187
188	ln, err := newLocalListener("tcp")
189	if err != nil {
190		t.Fatal(err)
191	}
192	defer ln.Close()
193
194	ch := make(chan Conn, 1)
195	go func() {
196		defer close(ch)
197		c, err := ln.Accept()
198		if err != nil {
199			t.Error(err)
200			return
201		}
202		ch <- c
203	}()
204	c1, err := Dial("tcp", ln.Addr().String())
205	if err != nil {
206		t.Fatal(err)
207	}
208	defer c1.Close()
209	c2 := <-ch
210	if c2 == nil {
211		t.Fatal("no server side connection")
212	}
213	c2.Close()
214
215	// 1 GB of data should be enough to notice the connection is gone.
216	// Just a few bytes is not enough.
217	// Arrange to reuse the same 1 MB buffer so that we don't allocate much.
218	buf := make([]byte, 1<<20)
219	buffers := make(Buffers, 1<<10)
220	for i := range buffers {
221		buffers[i] = buf
222	}
223	if _, err := buffers.WriteTo(c1); err == nil {
224		t.Fatal("Buffers.WriteTo(closed conn) succeeded, want error")
225	}
226}
227