1package http_test
2
3import (
4	"context"
5	"errors"
6	"io/ioutil"
7	"net/http"
8	"net/http/httptest"
9	"strings"
10	"testing"
11	"time"
12
13	"github.com/go-kit/kit/endpoint"
14	httptransport "github.com/go-kit/kit/transport/http"
15)
16
17func TestServerBadDecode(t *testing.T) {
18	handler := httptransport.NewServer(
19		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
20		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") },
21		func(context.Context, http.ResponseWriter, interface{}) error { return nil },
22	)
23	server := httptest.NewServer(handler)
24	defer server.Close()
25	resp, _ := http.Get(server.URL)
26	if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
27		t.Errorf("want %d, have %d", want, have)
28	}
29}
30
31func TestServerBadEndpoint(t *testing.T) {
32	handler := httptransport.NewServer(
33		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") },
34		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
35		func(context.Context, http.ResponseWriter, interface{}) error { return nil },
36	)
37	server := httptest.NewServer(handler)
38	defer server.Close()
39	resp, _ := http.Get(server.URL)
40	if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
41		t.Errorf("want %d, have %d", want, have)
42	}
43}
44
45func TestServerBadEncode(t *testing.T) {
46	handler := httptransport.NewServer(
47		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
48		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
49		func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dang") },
50	)
51	server := httptest.NewServer(handler)
52	defer server.Close()
53	resp, _ := http.Get(server.URL)
54	if want, have := http.StatusInternalServerError, resp.StatusCode; want != have {
55		t.Errorf("want %d, have %d", want, have)
56	}
57}
58
59func TestServerErrorEncoder(t *testing.T) {
60	errTeapot := errors.New("teapot")
61	code := func(err error) int {
62		if err == errTeapot {
63			return http.StatusTeapot
64		}
65		return http.StatusInternalServerError
66	}
67	handler := httptransport.NewServer(
68		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot },
69		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
70		func(context.Context, http.ResponseWriter, interface{}) error { return nil },
71		httptransport.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }),
72	)
73	server := httptest.NewServer(handler)
74	defer server.Close()
75	resp, _ := http.Get(server.URL)
76	if want, have := http.StatusTeapot, resp.StatusCode; want != have {
77		t.Errorf("want %d, have %d", want, have)
78	}
79}
80
81func TestServerHappyPath(t *testing.T) {
82	step, response := testServer(t)
83	step()
84	resp := <-response
85	defer resp.Body.Close()
86	buf, _ := ioutil.ReadAll(resp.Body)
87	if want, have := http.StatusOK, resp.StatusCode; want != have {
88		t.Errorf("want %d, have %d (%s)", want, have, buf)
89	}
90}
91
92func TestMultipleServerBefore(t *testing.T) {
93	var (
94		headerKey    = "X-Henlo-Lizer"
95		headerVal    = "Helllo you stinky lizard"
96		statusCode   = http.StatusTeapot
97		responseBody = "go eat a fly ugly\n"
98		done         = make(chan struct{})
99	)
100	handler := httptransport.NewServer(
101		endpoint.Nop,
102		func(context.Context, *http.Request) (interface{}, error) {
103			return struct{}{}, nil
104		},
105		func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
106			w.Header().Set(headerKey, headerVal)
107			w.WriteHeader(statusCode)
108			w.Write([]byte(responseBody))
109			return nil
110		},
111		httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
112			ctx = context.WithValue(ctx, "one", 1)
113
114			return ctx
115		}),
116		httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context {
117			if _, ok := ctx.Value("one").(int); !ok {
118				t.Error("Value was not set properly when multiple ServerBefores are used")
119			}
120
121			close(done)
122			return ctx
123		}),
124	)
125
126	server := httptest.NewServer(handler)
127	defer server.Close()
128	go http.Get(server.URL)
129
130	select {
131	case <-done:
132	case <-time.After(time.Second):
133		t.Fatal("timeout waiting for finalizer")
134	}
135}
136
137func TestMultipleServerAfter(t *testing.T) {
138	var (
139		headerKey    = "X-Henlo-Lizer"
140		headerVal    = "Helllo you stinky lizard"
141		statusCode   = http.StatusTeapot
142		responseBody = "go eat a fly ugly\n"
143		done         = make(chan struct{})
144	)
145	handler := httptransport.NewServer(
146		endpoint.Nop,
147		func(context.Context, *http.Request) (interface{}, error) {
148			return struct{}{}, nil
149		},
150		func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
151			w.Header().Set(headerKey, headerVal)
152			w.WriteHeader(statusCode)
153			w.Write([]byte(responseBody))
154			return nil
155		},
156		httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context {
157			ctx = context.WithValue(ctx, "one", 1)
158
159			return ctx
160		}),
161		httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context {
162			if _, ok := ctx.Value("one").(int); !ok {
163				t.Error("Value was not set properly when multiple ServerAfters are used")
164			}
165
166			close(done)
167			return ctx
168		}),
169	)
170
171	server := httptest.NewServer(handler)
172	defer server.Close()
173	go http.Get(server.URL)
174
175	select {
176	case <-done:
177	case <-time.After(time.Second):
178		t.Fatal("timeout waiting for finalizer")
179	}
180}
181
182func TestServerFinalizer(t *testing.T) {
183	var (
184		headerKey    = "X-Henlo-Lizer"
185		headerVal    = "Helllo you stinky lizard"
186		statusCode   = http.StatusTeapot
187		responseBody = "go eat a fly ugly\n"
188		done         = make(chan struct{})
189	)
190	handler := httptransport.NewServer(
191		endpoint.Nop,
192		func(context.Context, *http.Request) (interface{}, error) {
193			return struct{}{}, nil
194		},
195		func(_ context.Context, w http.ResponseWriter, _ interface{}) error {
196			w.Header().Set(headerKey, headerVal)
197			w.WriteHeader(statusCode)
198			w.Write([]byte(responseBody))
199			return nil
200		},
201		httptransport.ServerFinalizer(func(ctx context.Context, code int, _ *http.Request) {
202			if want, have := statusCode, code; want != have {
203				t.Errorf("StatusCode: want %d, have %d", want, have)
204			}
205
206			responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header)
207			if want, have := headerVal, responseHeader.Get(headerKey); want != have {
208				t.Errorf("%s: want %q, have %q", headerKey, want, have)
209			}
210
211			responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64)
212			if want, have := int64(len(responseBody)), responseSize; want != have {
213				t.Errorf("response size: want %d, have %d", want, have)
214			}
215
216			close(done)
217		}),
218	)
219
220	server := httptest.NewServer(handler)
221	defer server.Close()
222	go http.Get(server.URL)
223
224	select {
225	case <-done:
226	case <-time.After(time.Second):
227		t.Fatal("timeout waiting for finalizer")
228	}
229}
230
// enhancedResponse is a response fixture with a single JSON field. It also
// supplies its own status code (402 Payment Required) and an X-Edward header,
// which TestEncodeJSONResponse expects EncodeJSONResponse to apply.
type enhancedResponse struct {
	Foo string `json:"foo"`
}

func (enhancedResponse) StatusCode() int      { return http.StatusPaymentRequired }
func (enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": {"Snowden"}} }
237
238func TestEncodeJSONResponse(t *testing.T) {
239	handler := httptransport.NewServer(
240		func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil },
241		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
242		httptransport.EncodeJSONResponse,
243	)
244
245	server := httptest.NewServer(handler)
246	defer server.Close()
247
248	resp, err := http.Get(server.URL)
249	if err != nil {
250		t.Fatal(err)
251	}
252	if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have {
253		t.Errorf("StatusCode: want %d, have %d", want, have)
254	}
255	if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have {
256		t.Errorf("X-Edward: want %q, have %q", want, have)
257	}
258	buf, _ := ioutil.ReadAll(resp.Body)
259	if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have {
260		t.Errorf("Body: want %s, have %s", want, have)
261	}
262}
263
// multiHeaderResponse is an empty response fixture whose Headers method
// returns a single header key (Vary) carrying two values.
type multiHeaderResponse struct{}

// Headers returns a Vary header with the values "Origin" and "User-Agent".
// The receiver is unnamed because it is unused.
func (multiHeaderResponse) Headers() http.Header {
	return http.Header{"Vary": []string{"Origin", "User-Agent"}}
}
269
270func TestAddMultipleHeaders(t *testing.T) {
271	handler := httptransport.NewServer(
272		func(context.Context, interface{}) (interface{}, error) { return multiHeaderResponse{}, nil },
273		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
274		httptransport.EncodeJSONResponse,
275	)
276
277	server := httptest.NewServer(handler)
278	defer server.Close()
279
280	resp, err := http.Get(server.URL)
281	if err != nil {
282		t.Fatal(err)
283	}
284	expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}}
285	for k, vls := range resp.Header {
286		for _, v := range vls {
287			delete((expect[k]), v)
288		}
289		if len(expect[k]) != 0 {
290			t.Errorf("Header: unexpected header %s: %v", k, expect[k])
291		}
292	}
293}
294
295type multiHeaderResponseError struct {
296	multiHeaderResponse
297	msg string
298}
299
300func (m multiHeaderResponseError) Error() string {
301	return m.msg
302}
303
304func TestAddMultipleHeadersErrorEncoder(t *testing.T) {
305	errStr := "oh no"
306	handler := httptransport.NewServer(
307		func(context.Context, interface{}) (interface{}, error) {
308			return nil, multiHeaderResponseError{msg: errStr}
309		},
310		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
311		httptransport.EncodeJSONResponse,
312	)
313
314	server := httptest.NewServer(handler)
315	defer server.Close()
316
317	resp, err := http.Get(server.URL)
318	if err != nil {
319		t.Fatal(err)
320	}
321	expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}}
322	for k, vls := range resp.Header {
323		for _, v := range vls {
324			delete((expect[k]), v)
325		}
326		if len(expect[k]) != 0 {
327			t.Errorf("Header: unexpected header %s: %v", k, expect[k])
328		}
329	}
330	if b, _ := ioutil.ReadAll(resp.Body); errStr != string(b) {
331		t.Errorf("ErrorEncoder: got: %q, expected: %q", b, errStr)
332	}
333}
334
// noContentResponse is a response fixture that asks for 204 No Content and has
// no fields to encode.
type noContentResponse struct{}

// StatusCode reports 204 No Content. The receiver is unnamed because it is unused.
func (noContentResponse) StatusCode() int {
	return http.StatusNoContent
}
338
339func TestEncodeNoContent(t *testing.T) {
340	handler := httptransport.NewServer(
341		func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil },
342		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
343		httptransport.EncodeJSONResponse,
344	)
345
346	server := httptest.NewServer(handler)
347	defer server.Close()
348
349	resp, err := http.Get(server.URL)
350	if err != nil {
351		t.Fatal(err)
352	}
353	if want, have := http.StatusNoContent, resp.StatusCode; want != have {
354		t.Errorf("StatusCode: want %d, have %d", want, have)
355	}
356	buf, _ := ioutil.ReadAll(resp.Body)
357	if want, have := 0, len(buf); want != have {
358		t.Errorf("Body: want no content, have %d bytes", have)
359	}
360}
361
// enhancedError is an error fixture with a fixed message. It also supplies a
// status code (418), a JSON rendering of itself, and an X-Enhanced header.
// TestEnhancedError expects all three to appear in the HTTP response.
type enhancedError struct{}

func (enhancedError) Error() string                { return "enhanced error" }
func (enhancedError) StatusCode() int              { return http.StatusTeapot }
func (enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil }
func (enhancedError) Headers() http.Header         { return http.Header{"X-Enhanced": {"1"}} }
368
369func TestEnhancedError(t *testing.T) {
370	handler := httptransport.NewServer(
371		func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} },
372		func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
373		func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil },
374	)
375
376	server := httptest.NewServer(handler)
377	defer server.Close()
378
379	resp, err := http.Get(server.URL)
380	if err != nil {
381		t.Fatal(err)
382	}
383	defer resp.Body.Close()
384	if want, have := http.StatusTeapot, resp.StatusCode; want != have {
385		t.Errorf("StatusCode: want %d, have %d", want, have)
386	}
387	if want, have := "1", resp.Header.Get("X-Enhanced"); want != have {
388		t.Errorf("X-Enhanced: want %q, have %q", want, have)
389	}
390	buf, _ := ioutil.ReadAll(resp.Body)
391	if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have {
392		t.Errorf("Body: want %s, have %s", want, have)
393	}
394}
395
396func TestNoOpRequestDecoder(t *testing.T) {
397	resw := httptest.NewRecorder()
398	req, err := http.NewRequest(http.MethodGet, "/", nil)
399	if err != nil {
400		t.Error("Failed to create request")
401	}
402	handler := httptransport.NewServer(
403		func(ctx context.Context, request interface{}) (interface{}, error) {
404			if request != nil {
405				t.Error("Expected nil request in endpoint when using NopRequestDecoder")
406			}
407			return nil, nil
408		},
409		httptransport.NopRequestDecoder,
410		httptransport.EncodeJSONResponse,
411	)
412	handler.ServeHTTP(resw, req)
413	if resw.Code != http.StatusOK {
414		t.Errorf("Expected status code %d but got %d", http.StatusOK, resw.Code)
415	}
416}
417
418func testServer(t *testing.T) (step func(), resp <-chan *http.Response) {
419	var (
420		stepch   = make(chan bool)
421		endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil }
422		response = make(chan *http.Response)
423		handler  = httptransport.NewServer(
424			endpoint,
425			func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
426			func(context.Context, http.ResponseWriter, interface{}) error { return nil },
427			httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { return ctx }),
428			httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { return ctx }),
429		)
430	)
431	go func() {
432		server := httptest.NewServer(handler)
433		defer server.Close()
434		resp, err := http.Get(server.URL)
435		if err != nil {
436			t.Error(err)
437			return
438		}
439		response <- resp
440	}()
441	return func() { stepch <- true }, response
442}
443