1package handlers
2
3import (
4	"net/http"
5	"net/http/httptest"
6	"strings"
7	"testing"
8)
9
10func TestDefaultCORSHandlerReturnsOk(t *testing.T) {
11	r := newRequest("GET", "http://www.example.com/")
12	rr := httptest.NewRecorder()
13
14	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
15
16	CORS()(testHandler).ServeHTTP(rr, r)
17
18	if status := rr.Code; status != http.StatusOK {
19		t.Fatalf("bad status: got %v want %v", status, http.StatusFound)
20	}
21}
22
23func TestDefaultCORSHandlerReturnsOkWithOrigin(t *testing.T) {
24	r := newRequest("GET", "http://www.example.com/")
25	r.Header.Set("Origin", r.URL.String())
26
27	rr := httptest.NewRecorder()
28
29	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
30
31	CORS()(testHandler).ServeHTTP(rr, r)
32
33	if status := rr.Code; status != http.StatusOK {
34		t.Fatalf("bad status: got %v want %v", status, http.StatusFound)
35	}
36}
37
38func TestCORSHandlerIgnoreOptionsFallsThrough(t *testing.T) {
39	r := newRequest("OPTIONS", "http://www.example.com/")
40	r.Header.Set("Origin", r.URL.String())
41
42	rr := httptest.NewRecorder()
43
44	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
45		w.WriteHeader(http.StatusTeapot)
46	})
47
48	CORS(IgnoreOptions())(testHandler).ServeHTTP(rr, r)
49
50	if status := rr.Code; status != http.StatusTeapot {
51		t.Fatalf("bad status: got %v want %v", status, http.StatusTeapot)
52	}
53}
54
55func TestCORSHandlerSetsExposedHeaders(t *testing.T) {
56	// Test default configuration.
57	r := newRequest("GET", "http://www.example.com/")
58	r.Header.Set("Origin", r.URL.String())
59
60	rr := httptest.NewRecorder()
61
62	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
63
64	CORS(ExposedHeaders([]string{"X-CORS-TEST"}))(testHandler).ServeHTTP(rr, r)
65
66	if status := rr.Code; status != http.StatusOK {
67		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
68	}
69
70	header := rr.HeaderMap.Get(corsExposeHeadersHeader)
71	if header != "X-Cors-Test" {
72		t.Fatal("bad header: expected X-Cors-Test header, got empty header for method.")
73	}
74}
75
76func TestCORSHandlerUnsetRequestMethodForPreflightBadRequest(t *testing.T) {
77	r := newRequest("OPTIONS", "http://www.example.com/")
78	r.Header.Set("Origin", r.URL.String())
79
80	rr := httptest.NewRecorder()
81
82	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
83
84	CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r)
85
86	if status := rr.Code; status != http.StatusBadRequest {
87		t.Fatalf("bad status: got %v want %v", status, http.StatusBadRequest)
88	}
89}
90
91func TestCORSHandlerInvalidRequestMethodForPreflightMethodNotAllowed(t *testing.T) {
92	r := newRequest("OPTIONS", "http://www.example.com/")
93	r.Header.Set("Origin", r.URL.String())
94	r.Header.Set(corsRequestMethodHeader, "DELETE")
95
96	rr := httptest.NewRecorder()
97
98	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
99
100	CORS()(testHandler).ServeHTTP(rr, r)
101
102	if status := rr.Code; status != http.StatusMethodNotAllowed {
103		t.Fatalf("bad status: got %v want %v", status, http.StatusMethodNotAllowed)
104	}
105}
106
107func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandler(t *testing.T) {
108	r := newRequest("OPTIONS", "http://www.example.com/")
109	r.Header.Set("Origin", r.URL.String())
110	r.Header.Set(corsRequestMethodHeader, "GET")
111
112	rr := httptest.NewRecorder()
113
114	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
115		t.Fatal("Options request must not be passed to next handler")
116	})
117
118	CORS()(testHandler).ServeHTTP(rr, r)
119
120	if status := rr.Code; status != http.StatusOK {
121		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
122	}
123}
124
125func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWithCustomStatusCode(t *testing.T) {
126	statusCode := 204
127	r := newRequest("OPTIONS", "http://www.example.com/")
128	r.Header.Set("Origin", r.URL.String())
129	r.Header.Set(corsRequestMethodHeader, "GET")
130
131	rr := httptest.NewRecorder()
132
133	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
134		t.Fatal("Options request must not be passed to next handler")
135	})
136
137	CORS(OptionStatusCode(statusCode))(testHandler).ServeHTTP(rr, r)
138
139	if status := rr.Code; status != statusCode {
140		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
141	}
142}
143
144func TestCORSHandlerOptionsRequestMustNotBePassedToNextHandlerWhenOriginNotAllowed(t *testing.T) {
145	r := newRequest("OPTIONS", "http://www.example.com/")
146	r.Header.Set("Origin", r.URL.String())
147	r.Header.Set(corsRequestMethodHeader, "GET")
148
149	rr := httptest.NewRecorder()
150
151	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
152		t.Fatal("Options request must not be passed to next handler")
153	})
154
155	CORS(AllowedOrigins([]string{}))(testHandler).ServeHTTP(rr, r)
156
157	if status := rr.Code; status != http.StatusOK {
158		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
159	}
160}
161
162func TestCORSHandlerAllowedMethodForPreflight(t *testing.T) {
163	r := newRequest("OPTIONS", "http://www.example.com/")
164	r.Header.Set("Origin", r.URL.String())
165	r.Header.Set(corsRequestMethodHeader, "DELETE")
166
167	rr := httptest.NewRecorder()
168
169	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
170
171	CORS(AllowedMethods([]string{"DELETE"}))(testHandler).ServeHTTP(rr, r)
172
173	if status := rr.Code; status != http.StatusOK {
174		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
175	}
176
177	header := rr.HeaderMap.Get(corsAllowMethodsHeader)
178	if header != "DELETE" {
179		t.Fatalf("bad header: expected DELETE method header, got empty header.")
180	}
181}
182
183func TestCORSHandlerAllowMethodsNotSetForSimpleRequestPreflight(t *testing.T) {
184	for _, method := range defaultCorsMethods {
185		r := newRequest("OPTIONS", "http://www.example.com/")
186		r.Header.Set("Origin", r.URL.String())
187		r.Header.Set(corsRequestMethodHeader, method)
188
189		rr := httptest.NewRecorder()
190
191		testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
192
193		CORS()(testHandler).ServeHTTP(rr, r)
194
195		if status := rr.Code; status != http.StatusOK {
196			t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
197		}
198
199		header := rr.HeaderMap.Get(corsAllowMethodsHeader)
200		if header != "" {
201			t.Fatalf("bad header: expected empty method header, got %s.", header)
202		}
203	}
204}
205
206func TestCORSHandlerAllowedHeaderNotSetForSimpleRequestPreflight(t *testing.T) {
207	for _, simpleHeader := range defaultCorsHeaders {
208		r := newRequest("OPTIONS", "http://www.example.com/")
209		r.Header.Set("Origin", r.URL.String())
210		r.Header.Set(corsRequestMethodHeader, "GET")
211		r.Header.Set(corsRequestHeadersHeader, simpleHeader)
212
213		rr := httptest.NewRecorder()
214
215		testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
216
217		CORS()(testHandler).ServeHTTP(rr, r)
218
219		if status := rr.Code; status != http.StatusOK {
220			t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
221		}
222
223		header := rr.HeaderMap.Get(corsAllowHeadersHeader)
224		if header != "" {
225			t.Fatalf("bad header: expected empty header, got %s.", header)
226		}
227	}
228}
229
230func TestCORSHandlerAllowedHeaderForPreflight(t *testing.T) {
231	r := newRequest("OPTIONS", "http://www.example.com/")
232	r.Header.Set("Origin", r.URL.String())
233	r.Header.Set(corsRequestMethodHeader, "POST")
234	r.Header.Set(corsRequestHeadersHeader, "Content-Type")
235
236	rr := httptest.NewRecorder()
237
238	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
239
240	CORS(AllowedHeaders([]string{"Content-Type"}))(testHandler).ServeHTTP(rr, r)
241
242	if status := rr.Code; status != http.StatusOK {
243		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
244	}
245
246	header := rr.HeaderMap.Get(corsAllowHeadersHeader)
247	if header != "Content-Type" {
248		t.Fatalf("bad header: expected Content-Type header, got empty header.")
249	}
250}
251
252func TestCORSHandlerInvalidHeaderForPreflightForbidden(t *testing.T) {
253	r := newRequest("OPTIONS", "http://www.example.com/")
254	r.Header.Set("Origin", r.URL.String())
255	r.Header.Set(corsRequestMethodHeader, "POST")
256	r.Header.Set(corsRequestHeadersHeader, "Content-Type")
257
258	rr := httptest.NewRecorder()
259
260	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
261
262	CORS()(testHandler).ServeHTTP(rr, r)
263
264	if status := rr.Code; status != http.StatusForbidden {
265		t.Fatalf("bad status: got %v want %v", status, http.StatusForbidden)
266	}
267}
268
269func TestCORSHandlerMaxAgeForPreflight(t *testing.T) {
270	r := newRequest("OPTIONS", "http://www.example.com/")
271	r.Header.Set("Origin", r.URL.String())
272	r.Header.Set(corsRequestMethodHeader, "POST")
273
274	rr := httptest.NewRecorder()
275
276	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
277
278	CORS(MaxAge(3500))(testHandler).ServeHTTP(rr, r)
279
280	if status := rr.Code; status != http.StatusOK {
281		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
282	}
283
284	header := rr.HeaderMap.Get(corsMaxAgeHeader)
285	if header != "600" {
286		t.Fatalf("bad header: expected %s to be %s, got %s.", corsMaxAgeHeader, "600", header)
287	}
288}
289
290func TestCORSHandlerAllowedCredentials(t *testing.T) {
291	r := newRequest("GET", "http://www.example.com/")
292	r.Header.Set("Origin", r.URL.String())
293
294	rr := httptest.NewRecorder()
295
296	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
297
298	CORS(AllowCredentials())(testHandler).ServeHTTP(rr, r)
299
300	if status := rr.Code; status != http.StatusOK {
301		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
302	}
303
304	header := rr.HeaderMap.Get(corsAllowCredentialsHeader)
305	if header != "true" {
306		t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowCredentialsHeader, "true", header)
307	}
308}
309
310func TestCORSHandlerMultipleAllowOriginsSetsVaryHeader(t *testing.T) {
311	r := newRequest("GET", "http://www.example.com/")
312	r.Header.Set("Origin", r.URL.String())
313
314	rr := httptest.NewRecorder()
315
316	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
317
318	CORS(AllowedOrigins([]string{r.URL.String(), "http://google.com"}))(testHandler).ServeHTTP(rr, r)
319
320	if status := rr.Code; status != http.StatusOK {
321		t.Fatalf("bad status: got %v want %v", status, http.StatusOK)
322	}
323
324	header := rr.HeaderMap.Get(corsVaryHeader)
325	if header != corsOriginHeader {
326		t.Fatalf("bad header: expected %s to be %s, got %s.", corsVaryHeader, corsOriginHeader, header)
327	}
328}
329
330func TestCORSWithMultipleHandlers(t *testing.T) {
331	var lastHandledBy string
332	corsMiddleware := CORS()
333
334	testHandler1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
335		lastHandledBy = "testHandler1"
336	})
337	testHandler2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
338		lastHandledBy = "testHandler2"
339	})
340
341	r1 := newRequest("GET", "http://www.example.com/")
342	rr1 := httptest.NewRecorder()
343	handler1 := corsMiddleware(testHandler1)
344
345	corsMiddleware(testHandler2)
346
347	handler1.ServeHTTP(rr1, r1)
348	if lastHandledBy != "testHandler1" {
349		t.Fatalf("bad CORS() registration: Handler served should be Handler registered")
350	}
351}
352
353func TestCORSOriginValidatorWithImplicitStar(t *testing.T) {
354	r := newRequest("GET", "http://a.example.com")
355	r.Header.Set("Origin", r.URL.String())
356	rr := httptest.NewRecorder()
357
358	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
359
360	originValidator := func(origin string) bool {
361		if strings.HasSuffix(origin, ".example.com") {
362			return true
363		}
364		return false
365	}
366
367	CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r)
368	header := rr.HeaderMap.Get(corsAllowOriginHeader)
369	if header != r.URL.String() {
370		t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header)
371	}
372}
373
374func TestCORSOriginValidatorWithExplicitStar(t *testing.T) {
375	r := newRequest("GET", "http://a.example.com")
376	r.Header.Set("Origin", r.URL.String())
377	rr := httptest.NewRecorder()
378
379	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
380
381	originValidator := func(origin string) bool {
382		if strings.HasSuffix(origin, ".example.com") {
383			return true
384		}
385		return false
386	}
387
388	CORS(
389		AllowedOriginValidator(originValidator),
390		AllowedOrigins([]string{"*"}),
391	)(testHandler).ServeHTTP(rr, r)
392	header := rr.HeaderMap.Get(corsAllowOriginHeader)
393	if header != "*" {
394		t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header)
395	}
396}
397
398func TestCORSAllowStar(t *testing.T) {
399	r := newRequest("GET", "http://a.example.com")
400	r.Header.Set("Origin", r.URL.String())
401	rr := httptest.NewRecorder()
402
403	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
404
405	CORS()(testHandler).ServeHTTP(rr, r)
406	header := rr.HeaderMap.Get(corsAllowOriginHeader)
407	if header != "*" {
408		t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header)
409	}
410}
411