1// Copyright 2013 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package netutil
6
7import (
8	"errors"
9	"fmt"
10	"io"
11	"io/ioutil"
12	"net"
13	"net/http"
14	"sync"
15	"sync/atomic"
16	"testing"
17	"time"
18
19	"golang.org/x/net/internal/nettest"
20)
21
22func TestLimitListener(t *testing.T) {
23	const max = 5
24	attempts := (nettest.MaxOpenFiles() - max) / 2
25	if attempts > 256 { // maximum length of accept queue is 128 by default
26		attempts = 256
27	}
28
29	l, err := net.Listen("tcp", "127.0.0.1:0")
30	if err != nil {
31		t.Fatal(err)
32	}
33	defer l.Close()
34	l = LimitListener(l, max)
35
36	var open int32
37	go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38		if n := atomic.AddInt32(&open, 1); n > max {
39			t.Errorf("%d open connections, want <= %d", n, max)
40		}
41		defer atomic.AddInt32(&open, -1)
42		time.Sleep(10 * time.Millisecond)
43		fmt.Fprint(w, "some body")
44	}))
45
46	var wg sync.WaitGroup
47	var failed int32
48	for i := 0; i < attempts; i++ {
49		wg.Add(1)
50		go func() {
51			defer wg.Done()
52			c := http.Client{Timeout: 3 * time.Second}
53			r, err := c.Get("http://" + l.Addr().String())
54			if err != nil {
55				t.Log(err)
56				atomic.AddInt32(&failed, 1)
57				return
58			}
59			defer r.Body.Close()
60			io.Copy(ioutil.Discard, r.Body)
61		}()
62	}
63	wg.Wait()
64
65	// We expect some Gets to fail as the kernel's accept queue is filled,
66	// but most should succeed.
67	if int(failed) >= attempts/2 {
68		t.Errorf("%d requests failed within %d attempts", failed, attempts)
69	}
70}
71
// errFake is the sentinel error that errorListener.Accept always returns.
var errFake = errors.New("fake error from errorListener")

// errorListener is a net.Listener stub whose Accept always fails with errFake.
// Only Accept is overridden. When the zero value errorListener{} is used, the
// embedded Listener is nil, so calling any other method on it would panic.
type errorListener struct{ net.Listener }

// Accept never produces a connection. It returns a nil net.Conn and errFake.
func (errorListener) Accept() (net.Conn, error) {
	var none net.Conn
	return none, errFake
}
81
82// This used to hang.
83func TestLimitListenerError(t *testing.T) {
84	donec := make(chan bool, 1)
85	go func() {
86		const n = 2
87		ll := LimitListener(errorListener{}, n)
88		for i := 0; i < n+1; i++ {
89			_, err := ll.Accept()
90			if err != errFake {
91				t.Fatalf("Accept error = %v; want errFake", err)
92			}
93		}
94		donec <- true
95	}()
96	select {
97	case <-donec:
98	case <-time.After(5 * time.Second):
99		t.Fatal("timeout. deadlock?")
100	}
101}
102
103func TestLimitListenerClose(t *testing.T) {
104	ln, err := net.Listen("tcp", "127.0.0.1:0")
105	if err != nil {
106		t.Fatal(err)
107	}
108	defer ln.Close()
109	ln = LimitListener(ln, 1)
110
111	doneCh := make(chan struct{})
112	defer close(doneCh)
113	go func() {
114		c, err := net.Dial("tcp", ln.Addr().String())
115		if err != nil {
116			t.Fatal(err)
117		}
118		defer c.Close()
119		<-doneCh
120	}()
121
122	c, err := ln.Accept()
123	if err != nil {
124		t.Fatal(err)
125	}
126	defer c.Close()
127
128	acceptDone := make(chan struct{})
129	go func() {
130		c, err := ln.Accept()
131		if err == nil {
132			c.Close()
133			t.Errorf("Unexpected successful Accept()")
134		}
135		close(acceptDone)
136	}()
137
138	// Wait a tiny bit to ensure the Accept() is blocking.
139	time.Sleep(10 * time.Millisecond)
140	ln.Close()
141
142	select {
143	case <-acceptDone:
144	case <-time.After(5 * time.Second):
145		t.Fatalf("Accept() still blocking")
146	}
147}
148