1package http_test 2 3import ( 4 "context" 5 "errors" 6 "io/ioutil" 7 "net/http" 8 "net/http/httptest" 9 "strings" 10 "testing" 11 "time" 12 13 "github.com/go-kit/kit/endpoint" 14 httptransport "github.com/go-kit/kit/transport/http" 15) 16 17func TestServerBadDecode(t *testing.T) { 18 handler := httptransport.NewServer( 19 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, 20 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, errors.New("dang") }, 21 func(context.Context, http.ResponseWriter, interface{}) error { return nil }, 22 ) 23 server := httptest.NewServer(handler) 24 defer server.Close() 25 resp, _ := http.Get(server.URL) 26 if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { 27 t.Errorf("want %d, have %d", want, have) 28 } 29} 30 31func TestServerBadEndpoint(t *testing.T) { 32 handler := httptransport.NewServer( 33 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("dang") }, 34 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 35 func(context.Context, http.ResponseWriter, interface{}) error { return nil }, 36 ) 37 server := httptest.NewServer(handler) 38 defer server.Close() 39 resp, _ := http.Get(server.URL) 40 if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { 41 t.Errorf("want %d, have %d", want, have) 42 } 43} 44 45func TestServerBadEncode(t *testing.T) { 46 handler := httptransport.NewServer( 47 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, 48 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 49 func(context.Context, http.ResponseWriter, interface{}) error { return errors.New("dang") }, 50 ) 51 server := httptest.NewServer(handler) 52 defer server.Close() 53 resp, _ := http.Get(server.URL) 54 if want, have := http.StatusInternalServerError, resp.StatusCode; want != have { 55 t.Errorf("want %d, have %d", want, have) 56 } 57} 58 59func TestServerErrorEncoder(t *testing.T) { 60 errTeapot := errors.New("teapot") 61 code := func(err error) int { 62 if err == errTeapot { 63 return http.StatusTeapot 64 } 65 return http.StatusInternalServerError 66 } 67 handler := httptransport.NewServer( 68 func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, 69 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 70 func(context.Context, http.ResponseWriter, interface{}) error { return nil }, 71 httptransport.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }), 72 ) 73 server := httptest.NewServer(handler) 74 defer server.Close() 75 resp, _ := http.Get(server.URL) 76 if want, have := http.StatusTeapot, resp.StatusCode; want != have { 77 t.Errorf("want %d, have %d", want, have) 78 } 79} 80 81func TestServerHappyPath(t *testing.T) { 82 step, response := testServer(t) 83 step() 84 resp := <-response 85 defer resp.Body.Close() 86 buf, _ := ioutil.ReadAll(resp.Body) 87 if want, have := http.StatusOK, resp.StatusCode; want != have { 88 t.Errorf("want %d, have %d (%s)", want, have, buf) 89 } 90} 91 92func TestMultipleServerBefore(t *testing.T) { 93 var ( 94 headerKey = "X-Henlo-Lizer" 95 headerVal = "Helllo you stinky lizard" 96 statusCode = http.StatusTeapot 97 responseBody = "go eat a fly ugly\n" 98 done = make(chan struct{}) 99 ) 100 handler := httptransport.NewServer( 101 endpoint.Nop, 102 func(context.Context, *http.Request) (interface{}, error) { 103 return struct{}{}, nil 104 }, 105 func(_ context.Context, w http.ResponseWriter, _ interface{}) error { 106 w.Header().Set(headerKey, headerVal) 107 w.WriteHeader(statusCode) 108 w.Write([]byte(responseBody)) 109 return nil 110 }, 111 httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { 112 ctx = context.WithValue(ctx, "one", 1) 113 114 return ctx 115 }), 116 httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { 117 if _, ok := ctx.Value("one").(int); !ok { 118 t.Error("Value was not set properly when multiple ServerBefores are used") 119 } 120 121 close(done) 122 return ctx 123 }), 124 ) 125 126 server := httptest.NewServer(handler) 127 defer server.Close() 128 go http.Get(server.URL) 129 130 select { 131 case <-done: 132 case <-time.After(time.Second): 133 t.Fatal("timeout waiting for finalizer") 134 } 135} 136 137func TestMultipleServerAfter(t *testing.T) { 138 var ( 139 headerKey = "X-Henlo-Lizer" 140 headerVal = "Helllo you stinky lizard" 141 statusCode = http.StatusTeapot 142 responseBody = "go eat a fly ugly\n" 143 done = make(chan struct{}) 144 ) 145 handler := httptransport.NewServer( 146 endpoint.Nop, 147 func(context.Context, *http.Request) (interface{}, error) { 148 return struct{}{}, nil 149 }, 150 func(_ context.Context, w http.ResponseWriter, _ interface{}) error { 151 w.Header().Set(headerKey, headerVal) 152 w.WriteHeader(statusCode) 153 w.Write([]byte(responseBody)) 154 return nil 155 }, 156 httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { 157 ctx = context.WithValue(ctx, "one", 1) 158 159 return ctx 160 }), 161 httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { 162 if _, ok := ctx.Value("one").(int); !ok { 163 t.Error("Value was not set properly when multiple ServerAfters are used") 164 } 165 166 close(done) 167 return ctx 168 }), 169 ) 170 171 server := httptest.NewServer(handler) 172 defer server.Close() 173 go http.Get(server.URL) 174 175 select { 176 case <-done: 177 case <-time.After(time.Second): 178 t.Fatal("timeout waiting for finalizer") 179 } 180} 181 182func TestServerFinalizer(t *testing.T) { 183 var ( 184 headerKey = "X-Henlo-Lizer" 185 headerVal = "Helllo you stinky lizard" 186 statusCode = http.StatusTeapot 187 responseBody = "go eat a fly ugly\n" 188 done = make(chan struct{}) 189 ) 190 handler := httptransport.NewServer( 191 endpoint.Nop, 192 func(context.Context, *http.Request) (interface{}, error) { 193 return struct{}{}, nil 194 }, 195 func(_ context.Context, w http.ResponseWriter, _ interface{}) error { 196 w.Header().Set(headerKey, headerVal) 197 w.WriteHeader(statusCode) 198 w.Write([]byte(responseBody)) 199 return nil 200 }, 201 httptransport.ServerFinalizer(func(ctx context.Context, code int, _ *http.Request) { 202 if want, have := statusCode, code; want != have { 203 t.Errorf("StatusCode: want %d, have %d", want, have) 204 } 205 206 responseHeader := ctx.Value(httptransport.ContextKeyResponseHeaders).(http.Header) 207 if want, have := headerVal, responseHeader.Get(headerKey); want != have { 208 t.Errorf("%s: want %q, have %q", headerKey, want, have) 209 } 210 211 responseSize := ctx.Value(httptransport.ContextKeyResponseSize).(int64) 212 if want, have := int64(len(responseBody)), responseSize; want != have { 213 t.Errorf("response size: want %d, have %d", want, have) 214 } 215 216 close(done) 217 }), 218 ) 219 220 server := httptest.NewServer(handler) 221 defer server.Close() 222 go http.Get(server.URL) 223 224 select { 225 case <-done: 226 case <-time.After(time.Second): 227 t.Fatal("timeout waiting for finalizer") 228 } 229} 230 231type enhancedResponse struct { 232 Foo string `json:"foo"` 233} 234 235func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired } 236func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} } 237 238func TestEncodeJSONResponse(t *testing.T) { 239 handler := httptransport.NewServer( 240 func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil }, 241 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 242 httptransport.EncodeJSONResponse, 243 ) 244 245 server := httptest.NewServer(handler) 246 defer server.Close() 247 248 resp, err := http.Get(server.URL) 249 if err != nil { 250 t.Fatal(err) 251 } 252 if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have { 253 t.Errorf("StatusCode: want %d, have %d", want, have) 254 } 255 if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have { 256 t.Errorf("X-Edward: want %q, have %q", want, have) 257 } 258 buf, _ := ioutil.ReadAll(resp.Body) 259 if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have { 260 t.Errorf("Body: want %s, have %s", want, have) 261 } 262} 263 264type multiHeaderResponse struct{} 265 266func (_ multiHeaderResponse) Headers() http.Header { 267 return http.Header{"Vary": []string{"Origin", "User-Agent"}} 268} 269 270func TestAddMultipleHeaders(t *testing.T) { 271 handler := httptransport.NewServer( 272 func(context.Context, interface{}) (interface{}, error) { return multiHeaderResponse{}, nil }, 273 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 274 httptransport.EncodeJSONResponse, 275 ) 276 277 server := httptest.NewServer(handler) 278 defer server.Close() 279 280 resp, err := http.Get(server.URL) 281 if err != nil { 282 t.Fatal(err) 283 } 284 expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}} 285 for k, vls := range resp.Header { 286 for _, v := range vls { 287 delete((expect[k]), v) 288 } 289 if len(expect[k]) != 0 { 290 t.Errorf("Header: unexpected header %s: %v", k, expect[k]) 291 } 292 } 293} 294 295type multiHeaderResponseError struct { 296 multiHeaderResponse 297 msg string 298} 299 300func (m multiHeaderResponseError) Error() string { 301 return m.msg 302} 303 304func TestAddMultipleHeadersErrorEncoder(t *testing.T) { 305 errStr := "oh no" 306 handler := httptransport.NewServer( 307 func(context.Context, interface{}) (interface{}, error) { 308 return nil, multiHeaderResponseError{msg: errStr} 309 }, 310 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 311 httptransport.EncodeJSONResponse, 312 ) 313 314 server := httptest.NewServer(handler) 315 defer server.Close() 316 317 resp, err := http.Get(server.URL) 318 if err != nil { 319 t.Fatal(err) 320 } 321 expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}} 322 for k, vls := range resp.Header { 323 for _, v := range vls { 324 delete((expect[k]), v) 325 } 326 if len(expect[k]) != 0 { 327 t.Errorf("Header: unexpected header %s: %v", k, expect[k]) 328 } 329 } 330 if b, _ := ioutil.ReadAll(resp.Body); errStr != string(b) { 331 t.Errorf("ErrorEncoder: got: %q, expected: %q", b, errStr) 332 } 333} 334 335type noContentResponse struct{} 336 337func (e noContentResponse) StatusCode() int { return http.StatusNoContent } 338 339func TestEncodeNoContent(t *testing.T) { 340 handler := httptransport.NewServer( 341 func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil }, 342 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 343 httptransport.EncodeJSONResponse, 344 ) 345 346 server := httptest.NewServer(handler) 347 defer server.Close() 348 349 resp, err := http.Get(server.URL) 350 if err != nil { 351 t.Fatal(err) 352 } 353 if want, have := http.StatusNoContent, resp.StatusCode; want != have { 354 t.Errorf("StatusCode: want %d, have %d", want, have) 355 } 356 buf, _ := ioutil.ReadAll(resp.Body) 357 if want, have := 0, len(buf); want != have { 358 t.Errorf("Body: want no content, have %d bytes", have) 359 } 360} 361 362type enhancedError struct{} 363 364func (e enhancedError) Error() string { return "enhanced error" } 365func (e enhancedError) StatusCode() int { return http.StatusTeapot } 366func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil } 367func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} } 368 369func TestEnhancedError(t *testing.T) { 370 handler := httptransport.NewServer( 371 func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} }, 372 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 373 func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil }, 374 ) 375 376 server := httptest.NewServer(handler) 377 defer server.Close() 378 379 resp, err := http.Get(server.URL) 380 if err != nil { 381 t.Fatal(err) 382 } 383 defer resp.Body.Close() 384 if want, have := http.StatusTeapot, resp.StatusCode; want != have { 385 t.Errorf("StatusCode: want %d, have %d", want, have) 386 } 387 if want, have := "1", resp.Header.Get("X-Enhanced"); want != have { 388 t.Errorf("X-Enhanced: want %q, have %q", want, have) 389 } 390 buf, _ := ioutil.ReadAll(resp.Body) 391 if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have { 392 t.Errorf("Body: want %s, have %s", want, have) 393 } 394} 395 396func TestNoOpRequestDecoder(t *testing.T) { 397 resw := httptest.NewRecorder() 398 req, err := http.NewRequest(http.MethodGet, "/", nil) 399 if err != nil { 400 t.Error("Failed to create request") 401 } 402 handler := httptransport.NewServer( 403 func(ctx context.Context, request interface{}) (interface{}, error) { 404 if request != nil { 405 t.Error("Expected nil request in endpoint when using NopRequestDecoder") 406 } 407 return nil, nil 408 }, 409 httptransport.NopRequestDecoder, 410 httptransport.EncodeJSONResponse, 411 ) 412 handler.ServeHTTP(resw, req) 413 if resw.Code != http.StatusOK { 414 t.Errorf("Expected status code %d but got %d", http.StatusOK, resw.Code) 415 } 416} 417 418func testServer(t *testing.T) (step func(), resp <-chan *http.Response) { 419 var ( 420 stepch = make(chan bool) 421 endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil } 422 response = make(chan *http.Response) 423 handler = httptransport.NewServer( 424 endpoint, 425 func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, 426 func(context.Context, http.ResponseWriter, interface{}) error { return nil }, 427 httptransport.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { return ctx }), 428 httptransport.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { return ctx }), 429 ) 430 ) 431 go func() { 432 server := httptest.NewServer(handler) 433 defer server.Close() 434 resp, err := http.Get(server.URL) 435 if err != nil { 436 t.Error(err) 437 return 438 } 439 response <- resp 440 }() 441 return func() { stepch <- true }, response 442} 443