1package handlers 2 3import ( 4 "net/http" 5 "net/http/httptest" 6 "strings" 7 "testing" 8) 9 10func TestDefaultCORSHandlerReturnsOk(t *testing.T) { 11 r := newRequest("GET", "http://www.example.com/") 12 rr := httptest.NewRecorder() 13 14 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 15 16 CORS()(testHandler).ServeHTTP(rr, r) 17 18 if status := rr.Code; status != http.StatusOK { 19 t.Fatalf("bad status: got %v want %v", status, http.StatusFound) 20 } 21} 22 23func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) { 24 r := newRequest("GET", "http://www.example.com/") 25 r.Header.Set("Origin", r.URL.String()) 26 27 rr := httptest.NewRecorder() 28 29 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 30 31 CORS()(testHandler).ServeHTTP(rr, r) 32 33 if status := rr.Code; status != http.StatusOK { 34 t.Fatalf("bad status: got %v want %v", status, http.StatusFound) 35 } 36} 37 38func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) { 39 r := newRequest("OPTIONS", "http://www.example.com/") 40 r.Header.Set("Origin", r.URL.String()) 41 42 rr := httptest.NewRecorder() 43 44 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 45 w.WriteHeader(http.StatusTeapot) 46 }) 47 48 CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r) 49 50 if status := rr.Code; status != http.StatusTeapot { 51 t.Fatalf("bad status: got %v want %v", status, http.StatusTeapot) 52 } 53} 54 55func TestCORSHandlerSetsExposedHeaders(t *testing.T) { 56 // Test default configuration. 57 r := newRequest("GET", "http://www.example.com/") 58 r.Header.Set("Origin", r.URL.String()) 59 60 rr := httptest.NewRecorder() 61 62 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 63 64 CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r) 65 66 if status := rr.Code; status != http.StatusOK { 67 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 68 } 69 70 header := rr.HeaderMap.Get(corsExposeHeadersHeader) 71 if header != "X-Cors-Test" { 72 t.Fatal("bad header: expected X-Cors-Test header, got empty header for method.") 73 } 74} 75 76func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) { 77 r := newRequest("OPTIONS", "http://www.example.com/") 78 r.Header.Set("Origin", r.URL.String()) 79 80 rr := httptest.NewRecorder() 81 82 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 83 84 CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) 85 86 if status := rr.Code; status != http.StatusBadRequest { 87 t.Fatalf("bad status: got %v want %v", status, http.StatusBadRequest) 88 } 89} 90 91func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) { 92 r := newRequest("OPTIONS", "http://www.example.com/") 93 r.Header.Set("Origin", r.URL.String()) 94 r.Header.Set(corsRequestMethodHeader, "DELETE") 95 96 rr := httptest.NewRecorder() 97 98 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 99 100 CORS()(testHandler).ServeHTTP(rr, r) 101 102 if status := rr.Code; status != http.StatusMethodNotAllowed { 103 t.Fatalf("bad status: got %v want %v", status, http.StatusMethodNotAllowed) 104 } 105} 106 107func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) { 108 r := newRequest("OPTIONS", "http://www.example.com/") 109 r.Header.Set("Origin", r.URL.String()) 110 r.Header.Set(corsRequestMethodHeader, "GET") 111 112 rr := httptest.NewRecorder() 113 114 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 115 t.Fatal("Options request must not be passed to next handler") 116 }) 117 118 CORS()(testHandler).ServeHTTP(rr, r) 119 120 if status := rr.Code; status != http.StatusOK { 121 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 122 } 123} 124 125func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) { 126 statusCode := 204 127 r := newRequest("OPTIONS", "http://www.example.com/") 128 r.Header.Set("Origin", r.URL.String()) 129 r.Header.Set(corsRequestMethodHeader, "GET") 130 131 rr := httptest.NewRecorder() 132 133 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 134 t.Fatal("Options request must not be passed to next handler") 135 }) 136 137 CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r) 138 139 if status := rr.Code; status != statusCode { 140 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 141 } 142} 143 144func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) { 145 r := newRequest("OPTIONS", "http://www.example.com/") 146 r.Header.Set("Origin", r.URL.String()) 147 r.Header.Set(corsRequestMethodHeader, "GET") 148 149 rr := httptest.NewRecorder() 150 151 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 152 t.Fatal("Options request must not be passed to next handler") 153 }) 154 155 CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r) 156 157 if status := rr.Code; status != http.StatusOK { 158 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 159 } 160} 161 162func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) { 163 r := newRequest("OPTIONS", "http://www.example.com/") 164 r.Header.Set("Origin", r.URL.String()) 165 r.Header.Set(corsRequestMethodHeader, "DELETE") 166 167 rr := httptest.NewRecorder() 168 169 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 170 171 CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r) 172 173 if status := rr.Code; status != http.StatusOK { 174 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 175 } 176 177 header := rr.HeaderMap.Get(corsAllowMethodsHeader) 178 if header != "DELETE" { 179 t.Fatalf("bad header: expected DELETE method header, got empty header.") 180 } 181} 182 183func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) { 184 for _, method := range defaultCorsMethods { 185 r := newRequest("OPTIONS", "http://www.example.com/") 186 r.Header.Set("Origin", r.URL.String()) 187 r.Header.Set(corsRequestMethodHeader, method) 188 189 rr := httptest.NewRecorder() 190 191 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 192 193 CORS()(testHandler).ServeHTTP(rr, r) 194 195 if status := rr.Code; status != http.StatusOK { 196 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 197 } 198 199 header := rr.HeaderMap.Get(corsAllowMethodsHeader) 200 if header != "" { 201 t.Fatalf("bad header: expected empty method header, got %s.", header) 202 } 203 } 204} 205 206func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) { 207 for _, simpleHeader := range defaultCorsHeaders { 208 r := newRequest("OPTIONS", "http://www.example.com/") 209 r.Header.Set("Origin", r.URL.String()) 210 r.Header.Set(corsRequestMethodHeader, "GET") 211 r.Header.Set(corsRequestHeadersHeader, simpleHeader) 212 213 rr := httptest.NewRecorder() 214 215 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 216 217 CORS()(testHandler).ServeHTTP(rr, r) 218 219 if status := rr.Code; status != http.StatusOK { 220 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 221 } 222 223 header := rr.HeaderMap.Get(corsAllowHeadersHeader) 224 if header != "" { 225 t.Fatalf("bad header: expected empty header, got %s.", header) 226 } 227 } 228} 229 230func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) { 231 r := newRequest("OPTIONS", "http://www.example.com/") 232 r.Header.Set("Origin", r.URL.String()) 233 r.Header.Set(corsRequestMethodHeader, "POST") 234 r.Header.Set(corsRequestHeadersHeader, "Content-Type") 235 236 rr := httptest.NewRecorder() 237 238 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 239 240 CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r) 241 242 if status := rr.Code; status != http.StatusOK { 243 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 244 } 245 246 header := rr.HeaderMap.Get(corsAllowHeadersHeader) 247 if header != "Content-Type" { 248 t.Fatalf("bad header: expected Content-Type header, got empty header.") 249 } 250} 251 252func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) { 253 r := newRequest("OPTIONS", "http://www.example.com/") 254 r.Header.Set("Origin", r.URL.String()) 255 r.Header.Set(corsRequestMethodHeader, "POST") 256 r.Header.Set(corsRequestHeadersHeader, "Content-Type") 257 258 rr := httptest.NewRecorder() 259 260 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 261 262 CORS()(testHandler).ServeHTTP(rr, r) 263 264 if status := rr.Code; status != http.StatusForbidden { 265 t.Fatalf("bad status: got %v want %v", status, http.StatusForbidden) 266 } 267} 268 269func TestCORSHandlerMaxAgeForPreflight(t *testing.T) { 270 r := newRequest("OPTIONS", "http://www.example.com/") 271 r.Header.Set("Origin", r.URL.String()) 272 r.Header.Set(corsRequestMethodHeader, "POST") 273 274 rr := httptest.NewRecorder() 275 276 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 277 278 CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r) 279 280 if status := rr.Code; status != http.StatusOK { 281 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 282 } 283 284 header := rr.HeaderMap.Get(corsMaxAgeHeader) 285 if header != "600" { 286 t.Fatalf("bad header: expected %s to be %s, got %s.", corsMaxAgeHeader, "600", header) 287 } 288} 289 290func TestCORSHandlerAllowedCredentials(t *testing.T) { 291 r := newRequest("GET", "http://www.example.com/") 292 r.Header.Set("Origin", r.URL.String()) 293 294 rr := httptest.NewRecorder() 295 296 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 297 298 CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r) 299 300 if status := rr.Code; status != http.StatusOK { 301 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 302 } 303 304 header := rr.HeaderMap.Get(corsAllowCredentialsHeader) 305 if header != "true" { 306 t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowCredentialsHeader, "true", header) 307 } 308} 309 310func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) { 311 r := newRequest("GET", "http://www.example.com/") 312 r.Header.Set("Origin", r.URL.String()) 313 314 rr := httptest.NewRecorder() 315 316 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 317 318 CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r) 319 320 if status := rr.Code; status != http.StatusOK { 321 t.Fatalf("bad status: got %v want %v", status, http.StatusOK) 322 } 323 324 header := rr.HeaderMap.Get(corsVaryHeader) 325 if header != corsOriginHeader { 326 t.Fatalf("bad header: expected %s to be %s, got %s.", corsVaryHeader, corsOriginHeader, header) 327 } 328} 329 330func TestCORSWithMultipleHandlers(t *testing.T) { 331 var lastHandledBy string 332 corsMiddleware := CORS() 333 334 testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 335 lastHandledBy = "testHandler1" 336 }) 337 testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 338 lastHandledBy = "testHandler2" 339 }) 340 341 r1 := newRequest("GET", "http://www.example.com/") 342 rr1 := httptest.NewRecorder() 343 handler1 := corsMiddleware(testHandler1) 344 345 corsMiddleware(testHandler2) 346 347 handler1.ServeHTTP(rr1, r1) 348 if lastHandledBy != "testHandler1" { 349 t.Fatalf("bad CORS() registration: Handler served should be Handler registered") 350 } 351} 352 353func TestCORSOriginValidatorWithImplicitStar(t *testing.T) { 354 r := newRequest("GET", "http://a.example.com") 355 r.Header.Set("Origin", r.URL.String()) 356 rr := httptest.NewRecorder() 357 358 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 359 360 originValidator := func(origin string) bool { 361 if strings.HasSuffix(origin, ".example.com") { 362 return true 363 } 364 return false 365 } 366 367 CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r) 368 header := rr.HeaderMap.Get(corsAllowOriginHeader) 369 if header != r.URL.String() { 370 t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header) 371 } 372} 373 374func TestCORSOriginValidatorWithExplicitStar(t *testing.T) { 375 r := newRequest("GET", "http://a.example.com") 376 r.Header.Set("Origin", r.URL.String()) 377 rr := httptest.NewRecorder() 378 379 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 380 381 originValidator := func(origin string) bool { 382 if strings.HasSuffix(origin, ".example.com") { 383 return true 384 } 385 return false 386 } 387 388 CORS( 389 AllowedOriginValidator(originValidator), 390 AllowedOrigins([]string{"*"}), 391 )(testHandler).ServeHTTP(rr, r) 392 header := rr.HeaderMap.Get(corsAllowOriginHeader) 393 if header != "*" { 394 t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header) 395 } 396} 397 398func TestCORSAllowStar(t *testing.T) { 399 r := newRequest("GET", "http://a.example.com") 400 r.Header.Set("Origin", r.URL.String()) 401 rr := httptest.NewRecorder() 402 403 testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 404 405 CORS()(testHandler).ServeHTTP(rr, r) 406 header := rr.HeaderMap.Get(corsAllowOriginHeader) 407 if header != "*" { 408 t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header) 409 } 410} 411