1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package httptest
6
7import (
8	"fmt"
9	"io"
10	"net/http"
11	"testing"
12)
13
14func TestRecorder(t *testing.T) {
15	type checkFunc func(*ResponseRecorder) error
16	check := func(fns ...checkFunc) []checkFunc { return fns }
17
18	hasStatus := func(wantCode int) checkFunc {
19		return func(rec *ResponseRecorder) error {
20			if rec.Code != wantCode {
21				return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
22			}
23			return nil
24		}
25	}
26	hasResultStatus := func(want string) checkFunc {
27		return func(rec *ResponseRecorder) error {
28			if rec.Result().Status != want {
29				return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
30			}
31			return nil
32		}
33	}
34	hasResultStatusCode := func(wantCode int) checkFunc {
35		return func(rec *ResponseRecorder) error {
36			if rec.Result().StatusCode != wantCode {
37				return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
38			}
39			return nil
40		}
41	}
42	hasResultContents := func(want string) checkFunc {
43		return func(rec *ResponseRecorder) error {
44			contentBytes, err := io.ReadAll(rec.Result().Body)
45			if err != nil {
46				return err
47			}
48			contents := string(contentBytes)
49			if contents != want {
50				return fmt.Errorf("Result().Body = %s; want %s", contents, want)
51			}
52			return nil
53		}
54	}
55	hasContents := func(want string) checkFunc {
56		return func(rec *ResponseRecorder) error {
57			if rec.Body.String() != want {
58				return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
59			}
60			return nil
61		}
62	}
63	hasFlush := func(want bool) checkFunc {
64		return func(rec *ResponseRecorder) error {
65			if rec.Flushed != want {
66				return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
67			}
68			return nil
69		}
70	}
71	hasOldHeader := func(key, want string) checkFunc {
72		return func(rec *ResponseRecorder) error {
73			if got := rec.HeaderMap.Get(key); got != want {
74				return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
75			}
76			return nil
77		}
78	}
79	hasHeader := func(key, want string) checkFunc {
80		return func(rec *ResponseRecorder) error {
81			if got := rec.Result().Header.Get(key); got != want {
82				return fmt.Errorf("final header %s = %q; want %q", key, got, want)
83			}
84			return nil
85		}
86	}
87	hasNotHeaders := func(keys ...string) checkFunc {
88		return func(rec *ResponseRecorder) error {
89			for _, k := range keys {
90				v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
91				if ok {
92					return fmt.Errorf("unexpected header %s with value %q", k, v)
93				}
94			}
95			return nil
96		}
97	}
98	hasTrailer := func(key, want string) checkFunc {
99		return func(rec *ResponseRecorder) error {
100			if got := rec.Result().Trailer.Get(key); got != want {
101				return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
102			}
103			return nil
104		}
105	}
106	hasNotTrailers := func(keys ...string) checkFunc {
107		return func(rec *ResponseRecorder) error {
108			trailers := rec.Result().Trailer
109			for _, k := range keys {
110				_, ok := trailers[http.CanonicalHeaderKey(k)]
111				if ok {
112					return fmt.Errorf("unexpected trailer %s", k)
113				}
114			}
115			return nil
116		}
117	}
118	hasContentLength := func(length int64) checkFunc {
119		return func(rec *ResponseRecorder) error {
120			if got := rec.Result().ContentLength; got != length {
121				return fmt.Errorf("ContentLength = %d; want %d", got, length)
122			}
123			return nil
124		}
125	}
126
127	for _, tt := range [...]struct {
128		name   string
129		h      func(w http.ResponseWriter, r *http.Request)
130		checks []checkFunc
131	}{
132		{
133			"200 default",
134			func(w http.ResponseWriter, r *http.Request) {},
135			check(hasStatus(200), hasContents("")),
136		},
137		{
138			"first code only",
139			func(w http.ResponseWriter, r *http.Request) {
140				w.WriteHeader(201)
141				w.WriteHeader(202)
142				w.Write([]byte("hi"))
143			},
144			check(hasStatus(201), hasContents("hi")),
145		},
146		{
147			"write sends 200",
148			func(w http.ResponseWriter, r *http.Request) {
149				w.Write([]byte("hi first"))
150				w.WriteHeader(201)
151				w.WriteHeader(202)
152			},
153			check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
154		},
155		{
156			"write string",
157			func(w http.ResponseWriter, r *http.Request) {
158				io.WriteString(w, "hi first")
159			},
160			check(
161				hasStatus(200),
162				hasContents("hi first"),
163				hasFlush(false),
164				hasHeader("Content-Type", "text/plain; charset=utf-8"),
165			),
166		},
167		{
168			"flush",
169			func(w http.ResponseWriter, r *http.Request) {
170				w.(http.Flusher).Flush() // also sends a 200
171				w.WriteHeader(201)
172			},
173			check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
174		},
175		{
176			"Content-Type detection",
177			func(w http.ResponseWriter, r *http.Request) {
178				io.WriteString(w, "<html>")
179			},
180			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
181		},
182		{
183			"no Content-Type detection with Transfer-Encoding",
184			func(w http.ResponseWriter, r *http.Request) {
185				w.Header().Set("Transfer-Encoding", "some encoding")
186				io.WriteString(w, "<html>")
187			},
188			check(hasHeader("Content-Type", "")), // no header
189		},
190		{
191			"no Content-Type detection if set explicitly",
192			func(w http.ResponseWriter, r *http.Request) {
193				w.Header().Set("Content-Type", "some/type")
194				io.WriteString(w, "<html>")
195			},
196			check(hasHeader("Content-Type", "some/type")),
197		},
198		{
199			"Content-Type detection doesn't crash if HeaderMap is nil",
200			func(w http.ResponseWriter, r *http.Request) {
201				// Act as if the user wrote new(httptest.ResponseRecorder)
202				// rather than using NewRecorder (which initializes
203				// HeaderMap)
204				w.(*ResponseRecorder).HeaderMap = nil
205				io.WriteString(w, "<html>")
206			},
207			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
208		},
209		{
210			"Header is not changed after write",
211			func(w http.ResponseWriter, r *http.Request) {
212				hdr := w.Header()
213				hdr.Set("Key", "correct")
214				w.WriteHeader(200)
215				hdr.Set("Key", "incorrect")
216			},
217			check(hasHeader("Key", "correct")),
218		},
219		{
220			"Trailer headers are correctly recorded",
221			func(w http.ResponseWriter, r *http.Request) {
222				w.Header().Set("Non-Trailer", "correct")
223				w.Header().Set("Trailer", "Trailer-A")
224				w.Header().Add("Trailer", "Trailer-B")
225				w.Header().Add("Trailer", "Trailer-C")
226				io.WriteString(w, "<html>")
227				w.Header().Set("Non-Trailer", "incorrect")
228				w.Header().Set("Trailer-A", "valuea")
229				w.Header().Set("Trailer-C", "valuec")
230				w.Header().Set("Trailer-NotDeclared", "should be omitted")
231				w.Header().Set("Trailer:Trailer-D", "with prefix")
232			},
233			check(
234				hasStatus(200),
235				hasHeader("Content-Type", "text/html; charset=utf-8"),
236				hasHeader("Non-Trailer", "correct"),
237				hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
238				hasTrailer("Trailer-A", "valuea"),
239				hasTrailer("Trailer-C", "valuec"),
240				hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
241				hasTrailer("Trailer-D", "with prefix"),
242			),
243		},
244		{
245			"Header set without any write", // Issue 15560
246			func(w http.ResponseWriter, r *http.Request) {
247				w.Header().Set("X-Foo", "1")
248
249				// Simulate somebody using
250				// new(ResponseRecorder) instead of
251				// using the constructor which sets
252				// this to 200
253				w.(*ResponseRecorder).Code = 0
254			},
255			check(
256				hasOldHeader("X-Foo", "1"),
257				hasStatus(0),
258				hasHeader("X-Foo", "1"),
259				hasResultStatus("200 OK"),
260				hasResultStatusCode(200),
261			),
262		},
263		{
264			"HeaderMap vs FinalHeaders", // more for Issue 15560
265			func(w http.ResponseWriter, r *http.Request) {
266				h := w.Header()
267				h.Set("X-Foo", "1")
268				w.Write([]byte("hi"))
269				h.Set("X-Foo", "2")
270				h.Set("X-Bar", "2")
271			},
272			check(
273				hasOldHeader("X-Foo", "2"),
274				hasOldHeader("X-Bar", "2"),
275				hasHeader("X-Foo", "1"),
276				hasNotHeaders("X-Bar"),
277			),
278		},
279		{
280			"setting Content-Length header",
281			func(w http.ResponseWriter, r *http.Request) {
282				body := "Some body"
283				contentLength := fmt.Sprintf("%d", len(body))
284				w.Header().Set("Content-Length", contentLength)
285				io.WriteString(w, body)
286			},
287			check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
288		},
289		{
290			"nil ResponseRecorder.Body", // Issue 26642
291			func(w http.ResponseWriter, r *http.Request) {
292				w.(*ResponseRecorder).Body = nil
293				io.WriteString(w, "hi")
294			},
295			check(hasResultContents("")), // check we don't crash reading the body
296
297		},
298	} {
299		t.Run(tt.name, func(t *testing.T) {
300			r, _ := http.NewRequest("GET", "http://foo.com/", nil)
301			h := http.HandlerFunc(tt.h)
302			rec := NewRecorder()
303			h.ServeHTTP(rec, r)
304			for _, check := range tt.checks {
305				if err := check(rec); err != nil {
306					t.Error(err)
307				}
308			}
309		})
310	}
311}
312
313// issue 39017 - disallow Content-Length values such as "+3"
314func TestParseContentLength(t *testing.T) {
315	tests := []struct {
316		cl   string
317		want int64
318	}{
319		{
320			cl:   "3",
321			want: 3,
322		},
323		{
324			cl:   "+3",
325			want: -1,
326		},
327		{
328			cl:   "-3",
329			want: -1,
330		},
331		{
332			// max int64, for safe conversion before returning
333			cl:   "9223372036854775807",
334			want: 9223372036854775807,
335		},
336		{
337			cl:   "9223372036854775808",
338			want: -1,
339		},
340	}
341
342	for _, tt := range tests {
343		if got := parseContentLength(tt.cl); got != tt.want {
344			t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
345		}
346	}
347}
348