1// Copyright 2013 Julien Schmidt. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be found
3// in the LICENSE file.
4
5package httprouter
6
7import (
8	"errors"
9	"fmt"
10	"net/http"
11	"net/http/httptest"
12	"reflect"
13	"testing"
14)
15
// mockResponseWriter is a no-op http.ResponseWriter used by the tests: it
// discards everything written to it and hands out a fresh, empty header map on
// every Header call, so headers set by a handler are not retained.
type mockResponseWriter struct{}

// Header returns a newly allocated, empty header map.
func (rw *mockResponseWriter) Header() http.Header {
	return make(http.Header)
}

// Write discards p and reports all of it as written.
func (rw *mockResponseWriter) Write(p []byte) (int, error) {
	return len(p), nil
}

// WriteString discards s and reports all of it as written.
func (rw *mockResponseWriter) WriteString(s string) (int, error) {
	return len(s), nil
}

// WriteHeader ignores the status code it is given.
func (rw *mockResponseWriter) WriteHeader(int) {}
31
32func TestParams(t *testing.T) {
33	ps := Params{
34		Param{"param1", "value1"},
35		Param{"param2", "value2"},
36		Param{"param3", "value3"},
37	}
38	for i := range ps {
39		if val := ps.ByName(ps[i].Key); val != ps[i].Value {
40			t.Errorf("Wrong value for %s: Got %s; Want %s", ps[i].Key, val, ps[i].Value)
41		}
42	}
43	if val := ps.ByName("noKey"); val != "" {
44		t.Errorf("Expected empty string for not found key; got: %s", val)
45	}
46}
47
48func TestRouter(t *testing.T) {
49	router := New()
50
51	routed := false
52	router.Handle(http.MethodGet, "/user/:name", func(w http.ResponseWriter, r *http.Request, ps Params) {
53		routed = true
54		want := Params{Param{"name", "gopher"}}
55		if !reflect.DeepEqual(ps, want) {
56			t.Fatalf("wrong wildcard values: want %v, got %v", want, ps)
57		}
58	})
59
60	w := new(mockResponseWriter)
61
62	req, _ := http.NewRequest(http.MethodGet, "/user/gopher", nil)
63	router.ServeHTTP(w, req)
64
65	if !routed {
66		t.Fatal("routing failed")
67	}
68}
69
// handlerStruct is a plain http.Handler that records, through the flag it
// points to, that it has been invoked.
type handlerStruct struct {
	handled *bool // set to true by ServeHTTP
}

// ServeHTTP marks the handler as invoked; the writer and request are ignored.
func (hs handlerStruct) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {
	*hs.handled = true
}
77
78func TestRouterAPI(t *testing.T) {
79	var get, head, options, post, put, patch, delete, handler, handlerFunc bool
80
81	httpHandler := handlerStruct{&handler}
82
83	router := New()
84	router.GET("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
85		get = true
86	})
87	router.HEAD("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
88		head = true
89	})
90	router.OPTIONS("/GET", func(w http.ResponseWriter, r *http.Request, _ Params) {
91		options = true
92	})
93	router.POST("/POST", func(w http.ResponseWriter, r *http.Request, _ Params) {
94		post = true
95	})
96	router.PUT("/PUT", func(w http.ResponseWriter, r *http.Request, _ Params) {
97		put = true
98	})
99	router.PATCH("/PATCH", func(w http.ResponseWriter, r *http.Request, _ Params) {
100		patch = true
101	})
102	router.DELETE("/DELETE", func(w http.ResponseWriter, r *http.Request, _ Params) {
103		delete = true
104	})
105	router.Handler(http.MethodGet, "/Handler", httpHandler)
106	router.HandlerFunc(http.MethodGet, "/HandlerFunc", func(w http.ResponseWriter, r *http.Request) {
107		handlerFunc = true
108	})
109
110	w := new(mockResponseWriter)
111
112	r, _ := http.NewRequest(http.MethodGet, "/GET", nil)
113	router.ServeHTTP(w, r)
114	if !get {
115		t.Error("routing GET failed")
116	}
117
118	r, _ = http.NewRequest(http.MethodHead, "/GET", nil)
119	router.ServeHTTP(w, r)
120	if !head {
121		t.Error("routing HEAD failed")
122	}
123
124	r, _ = http.NewRequest(http.MethodOptions, "/GET", nil)
125	router.ServeHTTP(w, r)
126	if !options {
127		t.Error("routing OPTIONS failed")
128	}
129
130	r, _ = http.NewRequest(http.MethodPost, "/POST", nil)
131	router.ServeHTTP(w, r)
132	if !post {
133		t.Error("routing POST failed")
134	}
135
136	r, _ = http.NewRequest(http.MethodPut, "/PUT", nil)
137	router.ServeHTTP(w, r)
138	if !put {
139		t.Error("routing PUT failed")
140	}
141
142	r, _ = http.NewRequest(http.MethodPatch, "/PATCH", nil)
143	router.ServeHTTP(w, r)
144	if !patch {
145		t.Error("routing PATCH failed")
146	}
147
148	r, _ = http.NewRequest(http.MethodDelete, "/DELETE", nil)
149	router.ServeHTTP(w, r)
150	if !delete {
151		t.Error("routing DELETE failed")
152	}
153
154	r, _ = http.NewRequest(http.MethodGet, "/Handler", nil)
155	router.ServeHTTP(w, r)
156	if !handler {
157		t.Error("routing Handler failed")
158	}
159
160	r, _ = http.NewRequest(http.MethodGet, "/HandlerFunc", nil)
161	router.ServeHTTP(w, r)
162	if !handlerFunc {
163		t.Error("routing HandlerFunc failed")
164	}
165}
166
167func TestRouterRoot(t *testing.T) {
168	router := New()
169	recv := catchPanic(func() {
170		router.GET("noSlashRoot", nil)
171	})
172	if recv == nil {
173		t.Fatal("registering path not beginning with '/' did not panic")
174	}
175}
176
177func TestRouterChaining(t *testing.T) {
178	router1 := New()
179	router2 := New()
180	router1.NotFound = router2
181
182	fooHit := false
183	router1.POST("/foo", func(w http.ResponseWriter, req *http.Request, _ Params) {
184		fooHit = true
185		w.WriteHeader(http.StatusOK)
186	})
187
188	barHit := false
189	router2.POST("/bar", func(w http.ResponseWriter, req *http.Request, _ Params) {
190		barHit = true
191		w.WriteHeader(http.StatusOK)
192	})
193
194	r, _ := http.NewRequest(http.MethodPost, "/foo", nil)
195	w := httptest.NewRecorder()
196	router1.ServeHTTP(w, r)
197	if !(w.Code == http.StatusOK && fooHit) {
198		t.Errorf("Regular routing failed with router chaining.")
199		t.FailNow()
200	}
201
202	r, _ = http.NewRequest(http.MethodPost, "/bar", nil)
203	w = httptest.NewRecorder()
204	router1.ServeHTTP(w, r)
205	if !(w.Code == http.StatusOK && barHit) {
206		t.Errorf("Chained routing failed with router chaining.")
207		t.FailNow()
208	}
209
210	r, _ = http.NewRequest(http.MethodPost, "/qax", nil)
211	w = httptest.NewRecorder()
212	router1.ServeHTTP(w, r)
213	if !(w.Code == http.StatusNotFound) {
214		t.Errorf("NotFound behavior failed with router chaining.")
215		t.FailNow()
216	}
217}
218
219func BenchmarkAllowed(b *testing.B) {
220	handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}
221
222	router := New()
223	router.POST("/path", handlerFunc)
224	router.GET("/path", handlerFunc)
225
226	b.Run("Global", func(b *testing.B) {
227		b.ReportAllocs()
228		for i := 0; i < b.N; i++ {
229			_ = router.allowed("*", http.MethodOptions)
230		}
231	})
232	b.Run("Path", func(b *testing.B) {
233		b.ReportAllocs()
234		for i := 0; i < b.N; i++ {
235			_ = router.allowed("/path", http.MethodOptions)
236		}
237	})
238}
239
240func TestRouterOPTIONS(t *testing.T) {
241	handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}
242
243	router := New()
244	router.POST("/path", handlerFunc)
245
246	// test not allowed
247	// * (server)
248	r, _ := http.NewRequest(http.MethodOptions, "*", nil)
249	w := httptest.NewRecorder()
250	router.ServeHTTP(w, r)
251	if !(w.Code == http.StatusOK) {
252		t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
253	} else if allow := w.Header().Get("Allow"); allow != "OPTIONS, POST" {
254		t.Error("unexpected Allow header value: " + allow)
255	}
256
257	// path
258	r, _ = http.NewRequest(http.MethodOptions, "/path", nil)
259	w = httptest.NewRecorder()
260	router.ServeHTTP(w, r)
261	if !(w.Code == http.StatusOK) {
262		t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
263	} else if allow := w.Header().Get("Allow"); allow != "OPTIONS, POST" {
264		t.Error("unexpected Allow header value: " + allow)
265	}
266
267	r, _ = http.NewRequest(http.MethodOptions, "/doesnotexist", nil)
268	w = httptest.NewRecorder()
269	router.ServeHTTP(w, r)
270	if !(w.Code == http.StatusNotFound) {
271		t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
272	}
273
274	// add another method
275	router.GET("/path", handlerFunc)
276
277	// set a global OPTIONS handler
278	router.GlobalOPTIONS = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
279		// Adjust status code to 204
280		w.WriteHeader(http.StatusNoContent)
281	})
282
283	// test again
284	// * (server)
285	r, _ = http.NewRequest(http.MethodOptions, "*", nil)
286	w = httptest.NewRecorder()
287	router.ServeHTTP(w, r)
288	if !(w.Code == http.StatusNoContent) {
289		t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
290	} else if allow := w.Header().Get("Allow"); allow != "GET, OPTIONS, POST" {
291		t.Error("unexpected Allow header value: " + allow)
292	}
293
294	// path
295	r, _ = http.NewRequest(http.MethodOptions, "/path", nil)
296	w = httptest.NewRecorder()
297	router.ServeHTTP(w, r)
298	if !(w.Code == http.StatusNoContent) {
299		t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
300	} else if allow := w.Header().Get("Allow"); allow != "GET, OPTIONS, POST" {
301		t.Error("unexpected Allow header value: " + allow)
302	}
303
304	// custom handler
305	var custom bool
306	router.OPTIONS("/path", func(w http.ResponseWriter, r *http.Request, _ Params) {
307		custom = true
308	})
309
310	// test again
311	// * (server)
312	r, _ = http.NewRequest(http.MethodOptions, "*", nil)
313	w = httptest.NewRecorder()
314	router.ServeHTTP(w, r)
315	if !(w.Code == http.StatusNoContent) {
316		t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
317	} else if allow := w.Header().Get("Allow"); allow != "GET, OPTIONS, POST" {
318		t.Error("unexpected Allow header value: " + allow)
319	}
320	if custom {
321		t.Error("custom handler called on *")
322	}
323
324	// path
325	r, _ = http.NewRequest(http.MethodOptions, "/path", nil)
326	w = httptest.NewRecorder()
327	router.ServeHTTP(w, r)
328	if !(w.Code == http.StatusOK) {
329		t.Errorf("OPTIONS handling failed: Code=%d, Header=%v", w.Code, w.Header())
330	}
331	if !custom {
332		t.Error("custom handler not called")
333	}
334}
335
336func TestRouterNotAllowed(t *testing.T) {
337	handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}
338
339	router := New()
340	router.POST("/path", handlerFunc)
341
342	// test not allowed
343	r, _ := http.NewRequest(http.MethodGet, "/path", nil)
344	w := httptest.NewRecorder()
345	router.ServeHTTP(w, r)
346	if !(w.Code == http.StatusMethodNotAllowed) {
347		t.Errorf("NotAllowed handling failed: Code=%d, Header=%v", w.Code, w.Header())
348	} else if allow := w.Header().Get("Allow"); allow != "OPTIONS, POST" {
349		t.Error("unexpected Allow header value: " + allow)
350	}
351
352	// add another method
353	router.DELETE("/path", handlerFunc)
354	router.OPTIONS("/path", handlerFunc) // must be ignored
355
356	// test again
357	r, _ = http.NewRequest(http.MethodGet, "/path", nil)
358	w = httptest.NewRecorder()
359	router.ServeHTTP(w, r)
360	if !(w.Code == http.StatusMethodNotAllowed) {
361		t.Errorf("NotAllowed handling failed: Code=%d, Header=%v", w.Code, w.Header())
362	} else if allow := w.Header().Get("Allow"); allow != "DELETE, OPTIONS, POST" {
363		t.Error("unexpected Allow header value: " + allow)
364	}
365
366	// test custom handler
367	w = httptest.NewRecorder()
368	responseText := "custom method"
369	router.MethodNotAllowed = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
370		w.WriteHeader(http.StatusTeapot)
371		w.Write([]byte(responseText))
372	})
373	router.ServeHTTP(w, r)
374	if got := w.Body.String(); !(got == responseText) {
375		t.Errorf("unexpected response got %q want %q", got, responseText)
376	}
377	if w.Code != http.StatusTeapot {
378		t.Errorf("unexpected response code %d want %d", w.Code, http.StatusTeapot)
379	}
380	if allow := w.Header().Get("Allow"); allow != "DELETE, OPTIONS, POST" {
381		t.Error("unexpected Allow header value: " + allow)
382	}
383}
384
385func TestRouterNotFound(t *testing.T) {
386	handlerFunc := func(_ http.ResponseWriter, _ *http.Request, _ Params) {}
387
388	router := New()
389	router.GET("/path", handlerFunc)
390	router.GET("/dir/", handlerFunc)
391	router.GET("/", handlerFunc)
392
393	testRoutes := []struct {
394		route    string
395		code     int
396		location string
397	}{
398		{"/path/", 301, "/path"},   // TSR -/
399		{"/dir", 301, "/dir/"},     // TSR +/
400		{"", 301, "/"},             // TSR +/
401		{"/PATH", 301, "/path"},    // Fixed Case
402		{"/DIR/", 301, "/dir/"},    // Fixed Case
403		{"/PATH/", 301, "/path"},   // Fixed Case -/
404		{"/DIR", 301, "/dir/"},     // Fixed Case +/
405		{"/../path", 301, "/path"}, // CleanPath
406		{"/nope", 404, ""},         // NotFound
407	}
408	for _, tr := range testRoutes {
409		r, _ := http.NewRequest(http.MethodGet, tr.route, nil)
410		w := httptest.NewRecorder()
411		router.ServeHTTP(w, r)
412		if !(w.Code == tr.code && (w.Code == 404 || fmt.Sprint(w.Header().Get("Location")) == tr.location)) {
413			t.Errorf("NotFound handling route %s failed: Code=%d, Header=%v", tr.route, w.Code, w.Header().Get("Location"))
414		}
415	}
416
417	// Test custom not found handler
418	var notFound bool
419	router.NotFound = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
420		rw.WriteHeader(404)
421		notFound = true
422	})
423	r, _ := http.NewRequest(http.MethodGet, "/nope", nil)
424	w := httptest.NewRecorder()
425	router.ServeHTTP(w, r)
426	if !(w.Code == 404 && notFound == true) {
427		t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", w.Code, w.Header())
428	}
429
430	// Test other method than GET (want 307 instead of 301)
431	router.PATCH("/path", handlerFunc)
432	r, _ = http.NewRequest(http.MethodPatch, "/path/", nil)
433	w = httptest.NewRecorder()
434	router.ServeHTTP(w, r)
435	if !(w.Code == 307 && fmt.Sprint(w.Header()) == "map[Location:[/path]]") {
436		t.Errorf("Custom NotFound handler failed: Code=%d, Header=%v", w.Code, w.Header())
437	}
438
439	// Test special case where no node for the prefix "/" exists
440	router = New()
441	router.GET("/a", handlerFunc)
442	r, _ = http.NewRequest(http.MethodGet, "/", nil)
443	w = httptest.NewRecorder()
444	router.ServeHTTP(w, r)
445	if !(w.Code == 404) {
446		t.Errorf("NotFound handling route / failed: Code=%d", w.Code)
447	}
448}
449
450func TestRouterPanicHandler(t *testing.T) {
451	router := New()
452	panicHandled := false
453
454	router.PanicHandler = func(rw http.ResponseWriter, r *http.Request, p interface{}) {
455		panicHandled = true
456	}
457
458	router.Handle(http.MethodPut, "/user/:name", func(_ http.ResponseWriter, _ *http.Request, _ Params) {
459		panic("oops!")
460	})
461
462	w := new(mockResponseWriter)
463	req, _ := http.NewRequest(http.MethodPut, "/user/gopher", nil)
464
465	defer func() {
466		if rcv := recover(); rcv != nil {
467			t.Fatal("handling panic failed")
468		}
469	}()
470
471	router.ServeHTTP(w, req)
472
473	if !panicHandled {
474		t.Fatal("simulating failed")
475	}
476}
477
478func TestRouterLookup(t *testing.T) {
479	routed := false
480	wantHandle := func(_ http.ResponseWriter, _ *http.Request, _ Params) {
481		routed = true
482	}
483	wantParams := Params{Param{"name", "gopher"}}
484
485	router := New()
486
487	// try empty router first
488	handle, _, tsr := router.Lookup(http.MethodGet, "/nope")
489	if handle != nil {
490		t.Fatalf("Got handle for unregistered pattern: %v", handle)
491	}
492	if tsr {
493		t.Error("Got wrong TSR recommendation!")
494	}
495
496	// insert route and try again
497	router.GET("/user/:name", wantHandle)
498
499	handle, params, _ := router.Lookup(http.MethodGet, "/user/gopher")
500	if handle == nil {
501		t.Fatal("Got no handle!")
502	} else {
503		handle(nil, nil, nil)
504		if !routed {
505			t.Fatal("Routing failed!")
506		}
507	}
508
509	if !reflect.DeepEqual(params, wantParams) {
510		t.Fatalf("Wrong parameter values: want %v, got %v", wantParams, params)
511	}
512
513	handle, _, tsr = router.Lookup(http.MethodGet, "/user/gopher/")
514	if handle != nil {
515		t.Fatalf("Got handle for unregistered pattern: %v", handle)
516	}
517	if !tsr {
518		t.Error("Got no TSR recommendation!")
519	}
520
521	handle, _, tsr = router.Lookup(http.MethodGet, "/nope")
522	if handle != nil {
523		t.Fatalf("Got handle for unregistered pattern: %v", handle)
524	}
525	if tsr {
526		t.Error("Got wrong TSR recommendation!")
527	}
528}
529
530func TestRouterParamsFromContext(t *testing.T) {
531	routed := false
532
533	wantParams := Params{Param{"name", "gopher"}}
534	handlerFunc := func(_ http.ResponseWriter, req *http.Request) {
535		// get params from request context
536		params := ParamsFromContext(req.Context())
537
538		if !reflect.DeepEqual(params, wantParams) {
539			t.Fatalf("Wrong parameter values: want %v, got %v", wantParams, params)
540		}
541
542		routed = true
543	}
544
545	var nilParams Params
546	handlerFuncNil := func(_ http.ResponseWriter, req *http.Request) {
547		// get params from request context
548		params := ParamsFromContext(req.Context())
549
550		if !reflect.DeepEqual(params, nilParams) {
551			t.Fatalf("Wrong parameter values: want %v, got %v", nilParams, params)
552		}
553
554		routed = true
555	}
556	router := New()
557	router.HandlerFunc(http.MethodGet, "/user", handlerFuncNil)
558	router.HandlerFunc(http.MethodGet, "/user/:name", handlerFunc)
559
560	w := new(mockResponseWriter)
561	r, _ := http.NewRequest(http.MethodGet, "/user/gopher", nil)
562	router.ServeHTTP(w, r)
563	if !routed {
564		t.Fatal("Routing failed!")
565	}
566
567	routed = false
568	r, _ = http.NewRequest(http.MethodGet, "/user", nil)
569	router.ServeHTTP(w, r)
570	if !routed {
571		t.Fatal("Routing failed!")
572	}
573}
574
// mockFileSystem is an http.FileSystem stub that records whether Open was
// called and always fails to open anything.
type mockFileSystem struct {
	opened bool // set once Open has been called
}

// Open marks the file system as used and returns an error for every name.
func (fs *mockFileSystem) Open(_ string) (http.File, error) {
	fs.opened = true
	return nil, errors.New("this is just a mock")
}
583
584func TestRouterServeFiles(t *testing.T) {
585	router := New()
586	mfs := &mockFileSystem{}
587
588	recv := catchPanic(func() {
589		router.ServeFiles("/noFilepath", mfs)
590	})
591	if recv == nil {
592		t.Fatal("registering path not ending with '*filepath' did not panic")
593	}
594
595	router.ServeFiles("/*filepath", mfs)
596	w := new(mockResponseWriter)
597	r, _ := http.NewRequest(http.MethodGet, "/favicon.ico", nil)
598	router.ServeHTTP(w, r)
599	if !mfs.opened {
600		t.Error("serving file failed")
601	}
602}
603