1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package httptest
6
7import (
8	"bufio"
9	"io/ioutil"
10	"net"
11	"net/http"
12	"testing"
13)
14
15type newServerFunc func(http.Handler) *Server
16
17var newServers = map[string]newServerFunc{
18	"NewServer":    NewServer,
19	"NewTLSServer": NewTLSServer,
20
21	// The manual variants of newServer create a Server manually by only filling
22	// in the exported fields of Server.
23	"NewServerManual": func(h http.Handler) *Server {
24		ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
25		ts.Start()
26		return ts
27	},
28	"NewTLSServerManual": func(h http.Handler) *Server {
29		ts := &Server{Listener: newLocalListener(), Config: &http.Server{Handler: h}}
30		ts.StartTLS()
31		return ts
32	},
33}
34
35func TestServer(t *testing.T) {
36	for _, name := range []string{"NewServer", "NewServerManual"} {
37		t.Run(name, func(t *testing.T) {
38			newServer := newServers[name]
39			t.Run("Server", func(t *testing.T) { testServer(t, newServer) })
40			t.Run("GetAfterClose", func(t *testing.T) { testGetAfterClose(t, newServer) })
41			t.Run("ServerCloseBlocking", func(t *testing.T) { testServerCloseBlocking(t, newServer) })
42			t.Run("ServerCloseClientConnections", func(t *testing.T) { testServerCloseClientConnections(t, newServer) })
43			t.Run("ServerClientTransportType", func(t *testing.T) { testServerClientTransportType(t, newServer) })
44		})
45	}
46	for _, name := range []string{"NewTLSServer", "NewTLSServerManual"} {
47		t.Run(name, func(t *testing.T) {
48			newServer := newServers[name]
49			t.Run("ServerClient", func(t *testing.T) { testServerClient(t, newServer) })
50			t.Run("TLSServerClientTransportType", func(t *testing.T) { testTLSServerClientTransportType(t, newServer) })
51		})
52	}
53}
54
55func testServer(t *testing.T, newServer newServerFunc) {
56	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
57		w.Write([]byte("hello"))
58	}))
59	defer ts.Close()
60	res, err := http.Get(ts.URL)
61	if err != nil {
62		t.Fatal(err)
63	}
64	got, err := ioutil.ReadAll(res.Body)
65	res.Body.Close()
66	if err != nil {
67		t.Fatal(err)
68	}
69	if string(got) != "hello" {
70		t.Errorf("got %q, want hello", string(got))
71	}
72}
73
74// Issue 12781
75func testGetAfterClose(t *testing.T, newServer newServerFunc) {
76	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
77		w.Write([]byte("hello"))
78	}))
79
80	res, err := http.Get(ts.URL)
81	if err != nil {
82		t.Fatal(err)
83	}
84	got, err := ioutil.ReadAll(res.Body)
85	if err != nil {
86		t.Fatal(err)
87	}
88	if string(got) != "hello" {
89		t.Fatalf("got %q, want hello", string(got))
90	}
91
92	ts.Close()
93
94	res, err = http.Get(ts.URL)
95	if err == nil {
96		body, _ := ioutil.ReadAll(res.Body)
97		t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body)
98	}
99}
100
101func testServerCloseBlocking(t *testing.T, newServer newServerFunc) {
102	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
103		w.Write([]byte("hello"))
104	}))
105	dial := func() net.Conn {
106		c, err := net.Dial("tcp", ts.Listener.Addr().String())
107		if err != nil {
108			t.Fatal(err)
109		}
110		return c
111	}
112
113	// Keep one connection in StateNew (connected, but not sending anything)
114	cnew := dial()
115	defer cnew.Close()
116
117	// Keep one connection in StateIdle (idle after a request)
118	cidle := dial()
119	defer cidle.Close()
120	cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
121	_, err := http.ReadResponse(bufio.NewReader(cidle), nil)
122	if err != nil {
123		t.Fatal(err)
124	}
125
126	ts.Close() // test we don't hang here forever.
127}
128
129// Issue 14290
130func testServerCloseClientConnections(t *testing.T, newServer newServerFunc) {
131	var s *Server
132	s = newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
133		s.CloseClientConnections()
134	}))
135	defer s.Close()
136	res, err := http.Get(s.URL)
137	if err == nil {
138		res.Body.Close()
139		t.Fatalf("Unexpected response: %#v", res)
140	}
141}
142
143// Tests that the Server.Client method works and returns an http.Client that can hit
144// NewTLSServer without cert warnings.
145func testServerClient(t *testing.T, newTLSServer newServerFunc) {
146	ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
147		w.Write([]byte("hello"))
148	}))
149	defer ts.Close()
150	client := ts.Client()
151	res, err := client.Get(ts.URL)
152	if err != nil {
153		t.Fatal(err)
154	}
155	got, err := ioutil.ReadAll(res.Body)
156	res.Body.Close()
157	if err != nil {
158		t.Fatal(err)
159	}
160	if string(got) != "hello" {
161		t.Errorf("got %q, want hello", string(got))
162	}
163}
164
165// Tests that the Server.Client.Transport interface is implemented
166// by a *http.Transport.
167func testServerClientTransportType(t *testing.T, newServer newServerFunc) {
168	ts := newServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
169	}))
170	defer ts.Close()
171	client := ts.Client()
172	if _, ok := client.Transport.(*http.Transport); !ok {
173		t.Errorf("got %T, want *http.Transport", client.Transport)
174	}
175}
176
177// Tests that the TLS Server.Client.Transport interface is implemented
178// by a *http.Transport.
179func testTLSServerClientTransportType(t *testing.T, newTLSServer newServerFunc) {
180	ts := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
181	}))
182	defer ts.Close()
183	client := ts.Client()
184	if _, ok := client.Transport.(*http.Transport); !ok {
185		t.Errorf("got %T, want *http.Transport", client.Transport)
186	}
187}
188
// onlyCloseListener wraps a net.Listener but overrides Close to do nothing and
// report success. When built as onlyCloseListener{}, the embedded Listener is
// nil, so only Close is safe to call.
type onlyCloseListener struct {
	net.Listener
}

// Close is a no-op that always returns a nil error.
func (l onlyCloseListener) Close() error {
	return nil
}
194
195// Issue 19729: panic in Server.Close for values created directly
196// without a constructor (so the unexported client field is nil).
197func TestServerZeroValueClose(t *testing.T) {
198	ts := &Server{
199		Listener: onlyCloseListener{},
200		Config:   &http.Server{},
201	}
202
203	ts.Close() // tests that it doesn't panic
204}
205
206func TestTLSServerWithHTTP2(t *testing.T) {
207	modes := []struct {
208		name      string
209		wantProto string
210	}{
211		{"http1", "HTTP/1.1"},
212		{"http2", "HTTP/2.0"},
213	}
214
215	for _, tt := range modes {
216		t.Run(tt.name, func(t *testing.T) {
217			cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
218				w.Header().Set("X-Proto", r.Proto)
219			}))
220
221			switch tt.name {
222			case "http2":
223				cst.EnableHTTP2 = true
224				cst.StartTLS()
225			default:
226				cst.Start()
227			}
228
229			defer cst.Close()
230
231			res, err := cst.Client().Get(cst.URL)
232			if err != nil {
233				t.Fatalf("Failed to make request: %v", err)
234			}
235			if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
236				t.Fatalf("X-Proto header mismatch:\n\tgot:  %q\n\twant: %q", g, w)
237			}
238		})
239	}
240}
241