1package chi
2
3import (
4	"bytes"
5	"context"
6	"fmt"
7	"io"
8	"io/ioutil"
9	"net"
10	"net/http"
11	"net/http/httptest"
12	"os"
13	"sync"
14	"testing"
15	"time"
16)
17
18func TestMuxBasic(t *testing.T) {
19	var count uint64
20	countermw := func(next http.Handler) http.Handler {
21		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22			count++
23			next.ServeHTTP(w, r)
24		})
25	}
26
27	usermw := func(next http.Handler) http.Handler {
28		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29			ctx := r.Context()
30			ctx = context.WithValue(ctx, ctxKey{"user"}, "peter")
31			r = r.WithContext(ctx)
32			next.ServeHTTP(w, r)
33		})
34	}
35
36	exmw := func(next http.Handler) http.Handler {
37		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
38			ctx := context.WithValue(r.Context(), ctxKey{"ex"}, "a")
39			r = r.WithContext(ctx)
40			next.ServeHTTP(w, r)
41		})
42	}
43
44	logbuf := bytes.NewBufferString("")
45	logmsg := "logmw test"
46	logmw := func(next http.Handler) http.Handler {
47		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
48			logbuf.WriteString(logmsg)
49			next.ServeHTTP(w, r)
50		})
51	}
52
53	cxindex := func(w http.ResponseWriter, r *http.Request) {
54		ctx := r.Context()
55		user := ctx.Value(ctxKey{"user"}).(string)
56		w.WriteHeader(200)
57		w.Write([]byte(fmt.Sprintf("hi %s", user)))
58	}
59
60	ping := func(w http.ResponseWriter, r *http.Request) {
61		w.WriteHeader(200)
62		w.Write([]byte("."))
63	}
64
65	headPing := func(w http.ResponseWriter, r *http.Request) {
66		w.Header().Set("X-Ping", "1")
67		w.WriteHeader(200)
68	}
69
70	createPing := func(w http.ResponseWriter, r *http.Request) {
71		// create ....
72		w.WriteHeader(201)
73	}
74
75	pingAll := func(w http.ResponseWriter, r *http.Request) {
76		w.WriteHeader(200)
77		w.Write([]byte("ping all"))
78	}
79
80	pingAll2 := func(w http.ResponseWriter, r *http.Request) {
81		w.WriteHeader(200)
82		w.Write([]byte("ping all2"))
83	}
84
85	pingOne := func(w http.ResponseWriter, r *http.Request) {
86		idParam := URLParam(r, "id")
87		w.WriteHeader(200)
88		w.Write([]byte(fmt.Sprintf("ping one id: %s", idParam)))
89	}
90
91	pingWoop := func(w http.ResponseWriter, r *http.Request) {
92		w.WriteHeader(200)
93		w.Write([]byte("woop." + URLParam(r, "iidd")))
94	}
95
96	catchAll := func(w http.ResponseWriter, r *http.Request) {
97		w.WriteHeader(200)
98		w.Write([]byte("catchall"))
99	}
100
101	m := NewRouter()
102	m.Use(countermw)
103	m.Use(usermw)
104	m.Use(exmw)
105	m.Use(logmw)
106	m.Get("/", cxindex)
107	m.Method("GET", "/ping", http.HandlerFunc(ping))
108	m.MethodFunc("GET", "/pingall", pingAll)
109	m.MethodFunc("get", "/ping/all", pingAll)
110	m.Get("/ping/all2", pingAll2)
111
112	m.Head("/ping", headPing)
113	m.Post("/ping", createPing)
114	m.Get("/ping/{id}", pingWoop)
115	m.Get("/ping/{id}", pingOne) // expected to overwrite to pingOne handler
116	m.Get("/ping/{iidd}/woop", pingWoop)
117	m.HandleFunc("/admin/*", catchAll)
118	// m.Post("/admin/*", catchAll)
119
120	ts := httptest.NewServer(m)
121	defer ts.Close()
122
123	// GET /
124	if _, body := testRequest(t, ts, "GET", "/", nil); body != "hi peter" {
125		t.Fatalf(body)
126	}
127	tlogmsg, _ := logbuf.ReadString(0)
128	if tlogmsg != logmsg {
129		t.Error("expecting log message from middleware:", logmsg)
130	}
131
132	// GET /ping
133	if _, body := testRequest(t, ts, "GET", "/ping", nil); body != "." {
134		t.Fatalf(body)
135	}
136
137	// GET /pingall
138	if _, body := testRequest(t, ts, "GET", "/pingall", nil); body != "ping all" {
139		t.Fatalf(body)
140	}
141
142	// GET /ping/all
143	if _, body := testRequest(t, ts, "GET", "/ping/all", nil); body != "ping all" {
144		t.Fatalf(body)
145	}
146
147	// GET /ping/all2
148	if _, body := testRequest(t, ts, "GET", "/ping/all2", nil); body != "ping all2" {
149		t.Fatalf(body)
150	}
151
152	// GET /ping/123
153	if _, body := testRequest(t, ts, "GET", "/ping/123", nil); body != "ping one id: 123" {
154		t.Fatalf(body)
155	}
156
157	// GET /ping/allan
158	if _, body := testRequest(t, ts, "GET", "/ping/allan", nil); body != "ping one id: allan" {
159		t.Fatalf(body)
160	}
161
162	// GET /ping/1/woop
163	if _, body := testRequest(t, ts, "GET", "/ping/1/woop", nil); body != "woop.1" {
164		t.Fatalf(body)
165	}
166
167	// HEAD /ping
168	resp, err := http.Head(ts.URL + "/ping")
169	if err != nil {
170		t.Fatal(err)
171	}
172	if resp.StatusCode != 200 {
173		t.Error("head failed, should be 200")
174	}
175	if resp.Header.Get("X-Ping") == "" {
176		t.Error("expecting X-Ping header")
177	}
178
179	// GET /admin/catch-this
180	if _, body := testRequest(t, ts, "GET", "/admin/catch-thazzzzz", nil); body != "catchall" {
181		t.Fatalf(body)
182	}
183
184	// POST /admin/catch-this
185	resp, err = http.Post(ts.URL+"/admin/casdfsadfs", "text/plain", bytes.NewReader([]byte{}))
186	if err != nil {
187		t.Fatal(err)
188	}
189
190	body, err := ioutil.ReadAll(resp.Body)
191	if err != nil {
192		t.Fatal(err)
193	}
194	defer resp.Body.Close()
195
196	if resp.StatusCode != 200 {
197		t.Error("POST failed, should be 200")
198	}
199
200	if string(body) != "catchall" {
201		t.Error("expecting response body: 'catchall'")
202	}
203
204	// Custom http method DIE /ping/1/woop
205	if resp, body := testRequest(t, ts, "DIE", "/ping/1/woop", nil); body != "" || resp.StatusCode != 405 {
206		t.Fatalf(fmt.Sprintf("expecting 405 status and empty body, got %d '%s'", resp.StatusCode, body))
207	}
208}
209
210func TestMuxMounts(t *testing.T) {
211	r := NewRouter()
212
213	r.Get("/{hash}", func(w http.ResponseWriter, r *http.Request) {
214		v := URLParam(r, "hash")
215		w.Write([]byte(fmt.Sprintf("/%s", v)))
216	})
217
218	r.Route("/{hash}/share", func(r Router) {
219		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
220			v := URLParam(r, "hash")
221			w.Write([]byte(fmt.Sprintf("/%s/share", v)))
222		})
223		r.Get("/{network}", func(w http.ResponseWriter, r *http.Request) {
224			v := URLParam(r, "hash")
225			n := URLParam(r, "network")
226			w.Write([]byte(fmt.Sprintf("/%s/share/%s", v, n)))
227		})
228	})
229
230	m := NewRouter()
231	m.Mount("/sharing", r)
232
233	ts := httptest.NewServer(m)
234	defer ts.Close()
235
236	if _, body := testRequest(t, ts, "GET", "/sharing/aBc", nil); body != "/aBc" {
237		t.Fatalf(body)
238	}
239	if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share", nil); body != "/aBc/share" {
240		t.Fatalf(body)
241	}
242	if _, body := testRequest(t, ts, "GET", "/sharing/aBc/share/twitter", nil); body != "/aBc/share/twitter" {
243		t.Fatalf(body)
244	}
245}
246
247func TestMuxPlain(t *testing.T) {
248	r := NewRouter()
249	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
250		w.Write([]byte("bye"))
251	})
252	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
253		w.WriteHeader(404)
254		w.Write([]byte("nothing here"))
255	})
256
257	ts := httptest.NewServer(r)
258	defer ts.Close()
259
260	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
261		t.Fatalf(body)
262	}
263	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" {
264		t.Fatalf(body)
265	}
266}
267
268func TestMuxEmptyRoutes(t *testing.T) {
269	mux := NewRouter()
270
271	apiRouter := NewRouter()
272	// oops, we forgot to declare any route handlers
273
274	mux.Handle("/api*", apiRouter)
275
276	if _, body := testHandler(t, mux, "GET", "/", nil); body != "404 page not found\n" {
277		t.Fatalf(body)
278	}
279
280	if _, body := testHandler(t, apiRouter, "GET", "/", nil); body != "404 page not found\n" {
281		t.Fatalf(body)
282	}
283}
284
285// Test a mux that routes a trailing slash, see also middleware/strip_test.go
286// for an example of using a middleware to handle trailing slashes.
287func TestMuxTrailingSlash(t *testing.T) {
288	r := NewRouter()
289	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
290		w.WriteHeader(404)
291		w.Write([]byte("nothing here"))
292	})
293
294	subRoutes := NewRouter()
295	indexHandler := func(w http.ResponseWriter, r *http.Request) {
296		accountID := URLParam(r, "accountID")
297		w.Write([]byte(accountID))
298	}
299	subRoutes.Get("/", indexHandler)
300
301	r.Mount("/accounts/{accountID}", subRoutes)
302	r.Get("/accounts/{accountID}/", indexHandler)
303
304	ts := httptest.NewServer(r)
305	defer ts.Close()
306
307	if _, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" {
308		t.Fatalf(body)
309	}
310	if _, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" {
311		t.Fatalf(body)
312	}
313	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" {
314		t.Fatalf(body)
315	}
316}
317
318func TestMuxNestedNotFound(t *testing.T) {
319	r := NewRouter()
320
321	r.Use(func(next http.Handler) http.Handler {
322		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
323			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw"}, "mw"))
324			next.ServeHTTP(w, r)
325		})
326	})
327
328	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
329		w.Write([]byte("bye"))
330	})
331
332	r.With(func(next http.Handler) http.Handler {
333		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
334			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"with"}, "with"))
335			next.ServeHTTP(w, r)
336		})
337	}).NotFound(func(w http.ResponseWriter, r *http.Request) {
338		chkMw := r.Context().Value(ctxKey{"mw"}).(string)
339		chkWith := r.Context().Value(ctxKey{"with"}).(string)
340		w.WriteHeader(404)
341		w.Write([]byte(fmt.Sprintf("root 404 %s %s", chkMw, chkWith)))
342	})
343
344	sr1 := NewRouter()
345
346	sr1.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
347		w.Write([]byte("sub"))
348	})
349	sr1.Group(func(sr1 Router) {
350		sr1.Use(func(next http.Handler) http.Handler {
351			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
352				r = r.WithContext(context.WithValue(r.Context(), ctxKey{"mw2"}, "mw2"))
353				next.ServeHTTP(w, r)
354			})
355		})
356		sr1.NotFound(func(w http.ResponseWriter, r *http.Request) {
357			chkMw2 := r.Context().Value(ctxKey{"mw2"}).(string)
358			w.WriteHeader(404)
359			w.Write([]byte(fmt.Sprintf("sub 404 %s", chkMw2)))
360		})
361	})
362
363	sr2 := NewRouter()
364	sr2.Get("/sub", func(w http.ResponseWriter, r *http.Request) {
365		w.Write([]byte("sub2"))
366	})
367
368	r.Mount("/admin1", sr1)
369	r.Mount("/admin2", sr2)
370
371	ts := httptest.NewServer(r)
372	defer ts.Close()
373
374	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
375		t.Fatalf(body)
376	}
377	if _, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "root 404 mw with" {
378		t.Fatalf(body)
379	}
380	if _, body := testRequest(t, ts, "GET", "/admin1/sub", nil); body != "sub" {
381		t.Fatalf(body)
382	}
383	if _, body := testRequest(t, ts, "GET", "/admin1/nope", nil); body != "sub 404 mw2" {
384		t.Fatalf(body)
385	}
386	if _, body := testRequest(t, ts, "GET", "/admin2/sub", nil); body != "sub2" {
387		t.Fatalf(body)
388	}
389
390	// Not found pages should bubble up to the root.
391	if _, body := testRequest(t, ts, "GET", "/admin2/nope", nil); body != "root 404 mw with" {
392		t.Fatalf(body)
393	}
394}
395
396func TestMuxNestedMethodNotAllowed(t *testing.T) {
397	r := NewRouter()
398	r.Get("/root", func(w http.ResponseWriter, r *http.Request) {
399		w.Write([]byte("root"))
400	})
401	r.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
402		w.WriteHeader(405)
403		w.Write([]byte("root 405"))
404	})
405
406	sr1 := NewRouter()
407	sr1.Get("/sub1", func(w http.ResponseWriter, r *http.Request) {
408		w.Write([]byte("sub1"))
409	})
410	sr1.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
411		w.WriteHeader(405)
412		w.Write([]byte("sub1 405"))
413	})
414
415	sr2 := NewRouter()
416	sr2.Get("/sub2", func(w http.ResponseWriter, r *http.Request) {
417		w.Write([]byte("sub2"))
418	})
419
420	pathVar := NewRouter()
421	pathVar.Get("/{var}", func(w http.ResponseWriter, r *http.Request) {
422		w.Write([]byte("pv"))
423	})
424	pathVar.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
425		w.WriteHeader(405)
426		w.Write([]byte("pv 405"))
427	})
428
429	r.Mount("/prefix1", sr1)
430	r.Mount("/prefix2", sr2)
431	r.Mount("/pathVar", pathVar)
432
433	ts := httptest.NewServer(r)
434	defer ts.Close()
435
436	if _, body := testRequest(t, ts, "GET", "/root", nil); body != "root" {
437		t.Fatalf(body)
438	}
439	if _, body := testRequest(t, ts, "PUT", "/root", nil); body != "root 405" {
440		t.Fatalf(body)
441	}
442	if _, body := testRequest(t, ts, "GET", "/prefix1/sub1", nil); body != "sub1" {
443		t.Fatalf(body)
444	}
445	if _, body := testRequest(t, ts, "PUT", "/prefix1/sub1", nil); body != "sub1 405" {
446		t.Fatalf(body)
447	}
448	if _, body := testRequest(t, ts, "GET", "/prefix2/sub2", nil); body != "sub2" {
449		t.Fatalf(body)
450	}
451	if _, body := testRequest(t, ts, "PUT", "/prefix2/sub2", nil); body != "root 405" {
452		t.Fatalf(body)
453	}
454	if _, body := testRequest(t, ts, "GET", "/pathVar/myvar", nil); body != "pv" {
455		t.Fatalf(body)
456	}
457	if _, body := testRequest(t, ts, "DELETE", "/pathVar/myvar", nil); body != "pv 405" {
458		t.Fatalf(body)
459	}
460}
461
462func TestMuxComplicatedNotFound(t *testing.T) {
463	decorateRouter := func(r *Mux) {
464		// Root router with groups
465		r.Get("/auth", func(w http.ResponseWriter, r *http.Request) {
466			w.Write([]byte("auth get"))
467		})
468		r.Route("/public", func(r Router) {
469			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
470				w.Write([]byte("public get"))
471			})
472		})
473
474		// sub router with groups
475		sub0 := NewRouter()
476		sub0.Route("/resource", func(r Router) {
477			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
478				w.Write([]byte("private get"))
479			})
480		})
481		r.Mount("/private", sub0)
482
483		// sub router with groups
484		sub1 := NewRouter()
485		sub1.Route("/resource", func(r Router) {
486			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
487				w.Write([]byte("private get"))
488			})
489		})
490		r.With(func(next http.Handler) http.Handler { return next }).Mount("/private_mw", sub1)
491	}
492
493	testNotFound := func(t *testing.T, r *Mux) {
494		ts := httptest.NewServer(r)
495		defer ts.Close()
496
497		// check that we didn't break correct routes
498		if _, body := testRequest(t, ts, "GET", "/auth", nil); body != "auth get" {
499			t.Fatalf(body)
500		}
501		if _, body := testRequest(t, ts, "GET", "/public", nil); body != "public get" {
502			t.Fatalf(body)
503		}
504		if _, body := testRequest(t, ts, "GET", "/public/", nil); body != "public get" {
505			t.Fatalf(body)
506		}
507		if _, body := testRequest(t, ts, "GET", "/private/resource", nil); body != "private get" {
508			t.Fatalf(body)
509		}
510		// check custom not-found on all levels
511		if _, body := testRequest(t, ts, "GET", "/nope", nil); body != "custom not-found" {
512			t.Fatalf(body)
513		}
514		if _, body := testRequest(t, ts, "GET", "/public/nope", nil); body != "custom not-found" {
515			t.Fatalf(body)
516		}
517		if _, body := testRequest(t, ts, "GET", "/private/nope", nil); body != "custom not-found" {
518			t.Fatalf(body)
519		}
520		if _, body := testRequest(t, ts, "GET", "/private/resource/nope", nil); body != "custom not-found" {
521			t.Fatalf(body)
522		}
523		if _, body := testRequest(t, ts, "GET", "/private_mw/nope", nil); body != "custom not-found" {
524			t.Fatalf(body)
525		}
526		if _, body := testRequest(t, ts, "GET", "/private_mw/resource/nope", nil); body != "custom not-found" {
527			t.Fatalf(body)
528		}
529		// check custom not-found on trailing slash routes
530		if _, body := testRequest(t, ts, "GET", "/auth/", nil); body != "custom not-found" {
531			t.Fatalf(body)
532		}
533	}
534
535	t.Run("pre", func(t *testing.T) {
536		r := NewRouter()
537		r.NotFound(func(w http.ResponseWriter, r *http.Request) {
538			w.Write([]byte("custom not-found"))
539		})
540		decorateRouter(r)
541		testNotFound(t, r)
542	})
543
544	t.Run("post", func(t *testing.T) {
545		r := NewRouter()
546		decorateRouter(r)
547		r.NotFound(func(w http.ResponseWriter, r *http.Request) {
548			w.Write([]byte("custom not-found"))
549		})
550		testNotFound(t, r)
551	})
552}
553
554func TestMuxWith(t *testing.T) {
555	var cmwInit1, cmwHandler1 uint64
556	var cmwInit2, cmwHandler2 uint64
557	mw1 := func(next http.Handler) http.Handler {
558		cmwInit1++
559		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
560			cmwHandler1++
561			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline1"}, "yes"))
562			next.ServeHTTP(w, r)
563		})
564	}
565	mw2 := func(next http.Handler) http.Handler {
566		cmwInit2++
567		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
568			cmwHandler2++
569			r = r.WithContext(context.WithValue(r.Context(), ctxKey{"inline2"}, "yes"))
570			next.ServeHTTP(w, r)
571		})
572	}
573
574	r := NewRouter()
575	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
576		w.Write([]byte("bye"))
577	})
578	r.With(mw1).With(mw2).Get("/inline", func(w http.ResponseWriter, r *http.Request) {
579		v1 := r.Context().Value(ctxKey{"inline1"}).(string)
580		v2 := r.Context().Value(ctxKey{"inline2"}).(string)
581		w.Write([]byte(fmt.Sprintf("inline %s %s", v1, v2)))
582	})
583
584	ts := httptest.NewServer(r)
585	defer ts.Close()
586
587	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
588		t.Fatalf(body)
589	}
590	if _, body := testRequest(t, ts, "GET", "/inline", nil); body != "inline yes yes" {
591		t.Fatalf(body)
592	}
593	if cmwInit1 != 1 {
594		t.Fatalf("expecting cmwInit1 to be 1, got %d", cmwInit1)
595	}
596	if cmwHandler1 != 1 {
597		t.Fatalf("expecting cmwHandler1 to be 1, got %d", cmwHandler1)
598	}
599	if cmwInit2 != 1 {
600		t.Fatalf("expecting cmwInit2 to be 1, got %d", cmwInit2)
601	}
602	if cmwHandler2 != 1 {
603		t.Fatalf("expecting cmwHandler2 to be 1, got %d", cmwHandler2)
604	}
605}
606
607func TestRouterFromMuxWith(t *testing.T) {
608	t.Parallel()
609
610	r := NewRouter()
611
612	with := r.With(func(next http.Handler) http.Handler {
613		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
614			next.ServeHTTP(w, r)
615		})
616	})
617
618	with.Get("/with_middleware", func(w http.ResponseWriter, r *http.Request) {})
619
620	ts := httptest.NewServer(with)
621	defer ts.Close()
622
623	// Without the fix this test was committed with, this causes a panic.
624	testRequest(t, ts, http.MethodGet, "/with_middleware", nil)
625}
626
627func TestMuxMiddlewareStack(t *testing.T) {
628	var stdmwInit, stdmwHandler uint64
629	stdmw := func(next http.Handler) http.Handler {
630		stdmwInit++
631		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
632			stdmwHandler++
633			next.ServeHTTP(w, r)
634		})
635	}
636	_ = stdmw
637
638	var ctxmwInit, ctxmwHandler uint64
639	ctxmw := func(next http.Handler) http.Handler {
640		ctxmwInit++
641		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
642			ctxmwHandler++
643			ctx := r.Context()
644			ctx = context.WithValue(ctx, ctxKey{"count.ctxmwHandler"}, ctxmwHandler)
645			r = r.WithContext(ctx)
646			next.ServeHTTP(w, r)
647		})
648	}
649
650	var inCtxmwInit, inCtxmwHandler uint64
651	inCtxmw := func(next http.Handler) http.Handler {
652		inCtxmwInit++
653		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
654			inCtxmwHandler++
655			next.ServeHTTP(w, r)
656		})
657	}
658
659	r := NewRouter()
660	r.Use(stdmw)
661	r.Use(ctxmw)
662	r.Use(func(next http.Handler) http.Handler {
663		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
664			if r.URL.Path == "/ping" {
665				w.Write([]byte("pong"))
666				return
667			}
668			next.ServeHTTP(w, r)
669		})
670	})
671
672	var handlerCount uint64
673
674	r.With(inCtxmw).Get("/", func(w http.ResponseWriter, r *http.Request) {
675		handlerCount++
676		ctx := r.Context()
677		ctxmwHandlerCount := ctx.Value(ctxKey{"count.ctxmwHandler"}).(uint64)
678		w.Write([]byte(fmt.Sprintf("inits:%d reqs:%d ctxValue:%d", ctxmwInit, handlerCount, ctxmwHandlerCount)))
679	})
680
681	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
682		w.Write([]byte("wooot"))
683	})
684
685	ts := httptest.NewServer(r)
686	defer ts.Close()
687
688	testRequest(t, ts, "GET", "/", nil)
689	testRequest(t, ts, "GET", "/", nil)
690	var body string
691	_, body = testRequest(t, ts, "GET", "/", nil)
692	if body != "inits:1 reqs:3 ctxValue:3" {
693		t.Fatalf("got: '%s'", body)
694	}
695
696	_, body = testRequest(t, ts, "GET", "/ping", nil)
697	if body != "pong" {
698		t.Fatalf("got: '%s'", body)
699	}
700}
701
702func TestMuxRouteGroups(t *testing.T) {
703	var stdmwInit, stdmwHandler uint64
704
705	stdmw := func(next http.Handler) http.Handler {
706		stdmwInit++
707		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
708			stdmwHandler++
709			next.ServeHTTP(w, r)
710		})
711	}
712
713	var stdmwInit2, stdmwHandler2 uint64
714	stdmw2 := func(next http.Handler) http.Handler {
715		stdmwInit2++
716		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
717			stdmwHandler2++
718			next.ServeHTTP(w, r)
719		})
720	}
721
722	r := NewRouter()
723	r.Group(func(r Router) {
724		r.Use(stdmw)
725		r.Get("/group", func(w http.ResponseWriter, r *http.Request) {
726			w.Write([]byte("root group"))
727		})
728	})
729	r.Group(func(r Router) {
730		r.Use(stdmw2)
731		r.Get("/group2", func(w http.ResponseWriter, r *http.Request) {
732			w.Write([]byte("root group2"))
733		})
734	})
735
736	ts := httptest.NewServer(r)
737	defer ts.Close()
738
739	// GET /group
740	_, body := testRequest(t, ts, "GET", "/group", nil)
741	if body != "root group" {
742		t.Fatalf("got: '%s'", body)
743	}
744	if stdmwInit != 1 || stdmwHandler != 1 {
745		t.Logf("stdmw counters failed, should be 1:1, got %d:%d", stdmwInit, stdmwHandler)
746	}
747
748	// GET /group2
749	_, body = testRequest(t, ts, "GET", "/group2", nil)
750	if body != "root group2" {
751		t.Fatalf("got: '%s'", body)
752	}
753	if stdmwInit2 != 1 || stdmwHandler2 != 1 {
754		t.Fatalf("stdmw2 counters failed, should be 1:1, got %d:%d", stdmwInit2, stdmwHandler2)
755	}
756}
757
758func TestMuxBig(t *testing.T) {
759	r := bigMux()
760
761	ts := httptest.NewServer(r)
762	defer ts.Close()
763
764	var body, expected string
765
766	_, body = testRequest(t, ts, "GET", "/favicon.ico", nil)
767	if body != "fav" {
768		t.Fatalf("got '%s'", body)
769	}
770	_, body = testRequest(t, ts, "GET", "/hubs/4/view", nil)
771	if body != "/hubs/4/view reqid:1 session:anonymous" {
772		t.Fatalf("got '%v'", body)
773	}
774	_, body = testRequest(t, ts, "GET", "/hubs/4/view/index.html", nil)
775	if body != "/hubs/4/view/index.html reqid:1 session:anonymous" {
776		t.Fatalf("got '%s'", body)
777	}
778	_, body = testRequest(t, ts, "POST", "/hubs/ethereumhub/view/index.html", nil)
779	if body != "/hubs/ethereumhub/view/index.html reqid:1 session:anonymous" {
780		t.Fatalf("got '%s'", body)
781	}
782	_, body = testRequest(t, ts, "GET", "/", nil)
783	if body != "/ reqid:1 session:elvis" {
784		t.Fatalf("got '%s'", body)
785	}
786	_, body = testRequest(t, ts, "GET", "/suggestions", nil)
787	if body != "/suggestions reqid:1 session:elvis" {
788		t.Fatalf("got '%s'", body)
789	}
790	_, body = testRequest(t, ts, "GET", "/woot/444/hiiii", nil)
791	if body != "/woot/444/hiiii" {
792		t.Fatalf("got '%s'", body)
793	}
794	_, body = testRequest(t, ts, "GET", "/hubs/123", nil)
795	expected = "/hubs/123 reqid:1 session:elvis"
796	if body != expected {
797		t.Fatalf("expected:%s got:%s", expected, body)
798	}
799	_, body = testRequest(t, ts, "GET", "/hubs/123/touch", nil)
800	if body != "/hubs/123/touch reqid:1 session:elvis" {
801		t.Fatalf("got '%s'", body)
802	}
803	_, body = testRequest(t, ts, "GET", "/hubs/123/webhooks", nil)
804	if body != "/hubs/123/webhooks reqid:1 session:elvis" {
805		t.Fatalf("got '%s'", body)
806	}
807	_, body = testRequest(t, ts, "GET", "/hubs/123/posts", nil)
808	if body != "/hubs/123/posts reqid:1 session:elvis" {
809		t.Fatalf("got '%s'", body)
810	}
811	_, body = testRequest(t, ts, "GET", "/folders", nil)
812	if body != "404 page not found\n" {
813		t.Fatalf("got '%s'", body)
814	}
815	_, body = testRequest(t, ts, "GET", "/folders/", nil)
816	if body != "/folders/ reqid:1 session:elvis" {
817		t.Fatalf("got '%s'", body)
818	}
819	_, body = testRequest(t, ts, "GET", "/folders/public", nil)
820	if body != "/folders/public reqid:1 session:elvis" {
821		t.Fatalf("got '%s'", body)
822	}
823	_, body = testRequest(t, ts, "GET", "/folders/nothing", nil)
824	if body != "404 page not found\n" {
825		t.Fatalf("got '%s'", body)
826	}
827}
828
// bigMux builds the multi-level router fixture used by TestMuxBig.
//
// The router combines:
//   - two root middlewares (one stores ctxKey{"requestID"} = "1");
//   - a Group whose routes see ctxKey{"session.user"} = "anonymous";
//   - a Group whose routes see ctxKey{"session.user"} = "elvis", containing
//     nested Route sub-routers under /hubs, a separately built sub-router
//     (sr3) mounted at /webhooks behind a Chain middleware, and a "/folders/"
//     sub-router.
//
// Every handler writes a body that encodes its path, URL params and the
// context values above, so tests can see which route matched.
func bigMux() Router {
	var r *Mux
	// sr3 is assigned inside the /hubs/{hubID} closure below.
	var sr3 *Mux
	// var sr1, sr2, sr3, sr4, sr5, sr6 *Mux
	r = NewRouter()
	// Root middleware: store requestID "1" in the context of every request.
	r.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := context.WithValue(r.Context(), ctxKey{"requestID"}, "1")
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	})
	// Root middleware that only forwards to the next handler.
	r.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			next.ServeHTTP(w, r)
		})
	})
	// Group 1: routes served with session.user = "anonymous".
	r.Group(func(r Router) {
		r.Use(func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "anonymous")
				next.ServeHTTP(w, r.WithContext(ctx))
			})
		})
		r.Get("/favicon.ico", func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte("fav"))
		})
		r.Get("/hubs/{hubID}/view", func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			s := fmt.Sprintf("/hubs/%s/view reqid:%s session:%s", URLParam(r, "hubID"),
				ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
			w.Write([]byte(s))
		})
		// Reads hubID via URLParamFromCtx and the trailing wildcard via "*".
		r.Get("/hubs/{hubID}/view/*", func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			s := fmt.Sprintf("/hubs/%s/view/%s reqid:%s session:%s", URLParamFromCtx(ctx, "hubID"),
				URLParam(r, "*"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
			w.Write([]byte(s))
		})
		// Same path shape as the GET above, but the param is named hubSlug.
		r.Post("/hubs/{hubSlug}/view/*", func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			s := fmt.Sprintf("/hubs/%s/view/%s reqid:%s session:%s", URLParamFromCtx(ctx, "hubSlug"),
				URLParam(r, "*"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
			w.Write([]byte(s))
		})
	})
	// Group 2: routes served with session.user = "elvis".
	r.Group(func(r Router) {
		r.Use(func(next http.Handler) http.Handler {
			return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				ctx := context.WithValue(r.Context(), ctxKey{"session.user"}, "elvis")
				next.ServeHTTP(w, r.WithContext(ctx))
			})
		})
		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			s := fmt.Sprintf("/ reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
			w.Write([]byte(s))
		})
		r.Get("/suggestions", func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			s := fmt.Sprintf("/suggestions reqid:%s session:%s", ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
			w.Write([]byte(s))
		})

		r.Get("/woot/{wootID}/*", func(w http.ResponseWriter, r *http.Request) {
			s := fmt.Sprintf("/woot/%s/%s", URLParam(r, "wootID"), URLParam(r, "*"))
			w.Write([]byte(s))
		})

		r.Route("/hubs", func(r Router) {
			// Unchecked type assertion: panics if Route does not pass a *Mux.
			_ = r.(*Mux) // sr1
			r.Route("/{hubID}", func(r Router) {
				_ = r.(*Mux) // sr2
				r.Get("/", func(w http.ResponseWriter, r *http.Request) {
					ctx := r.Context()
					s := fmt.Sprintf("/hubs/%s reqid:%s session:%s",
						URLParam(r, "hubID"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
					w.Write([]byte(s))
				})
				r.Get("/touch", func(w http.ResponseWriter, r *http.Request) {
					ctx := r.Context()
					s := fmt.Sprintf("/hubs/%s/touch reqid:%s session:%s", URLParam(r, "hubID"),
						ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
					w.Write([]byte(s))
				})

				// Independent sub-router; its handlers still read hubID,
				// which is captured by the parent /{hubID} route.
				sr3 = NewRouter()
				sr3.Get("/", func(w http.ResponseWriter, r *http.Request) {
					ctx := r.Context()
					s := fmt.Sprintf("/hubs/%s/webhooks reqid:%s session:%s", URLParam(r, "hubID"),
						ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
					w.Write([]byte(s))
				})
				sr3.Route("/{webhookID}", func(r Router) {
					_ = r.(*Mux) // sr4
					r.Get("/", func(w http.ResponseWriter, r *http.Request) {
						ctx := r.Context()
						s := fmt.Sprintf("/hubs/%s/webhooks/%s reqid:%s session:%s", URLParam(r, "hubID"),
							URLParam(r, "webhookID"), ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
						w.Write([]byte(s))
					})
				})

				// Mount sr3 wrapped in a Chain middleware that sets
				// ctxKey{"hook"} = true (no handler in bigMux reads it).
				r.Mount("/webhooks", Chain(func(next http.Handler) http.Handler {
					return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
						next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ctxKey{"hook"}, true)))
					})
				}).Handler(sr3))

				r.Route("/posts", func(r Router) {
					_ = r.(*Mux) // sr5
					r.Get("/", func(w http.ResponseWriter, r *http.Request) {
						ctx := r.Context()
						s := fmt.Sprintf("/hubs/%s/posts reqid:%s session:%s", URLParam(r, "hubID"),
							ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
						w.Write([]byte(s))
					})
				})
			})
		})

		// The pattern has a trailing slash; TestMuxBig expects GET /folders
		// to 404 while GET /folders/ reaches the index handler.
		r.Route("/folders/", func(r Router) {
			_ = r.(*Mux) // sr6
			r.Get("/", func(w http.ResponseWriter, r *http.Request) {
				ctx := r.Context()
				s := fmt.Sprintf("/folders/ reqid:%s session:%s",
					ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
				w.Write([]byte(s))
			})
			r.Get("/public", func(w http.ResponseWriter, r *http.Request) {
				ctx := r.Context()
				s := fmt.Sprintf("/folders/public reqid:%s session:%s",
					ctx.Value(ctxKey{"requestID"}), ctx.Value(ctxKey{"session.user"}))
				w.Write([]byte(s))
			})
		})
	})

	return r
}
968
969func TestMuxSubroutesBasic(t *testing.T) {
970	hIndex := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
971		w.Write([]byte("index"))
972	})
973	hArticlesList := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
974		w.Write([]byte("articles-list"))
975	})
976	hSearchArticles := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
977		w.Write([]byte("search-articles"))
978	})
979	hGetArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
980		w.Write([]byte(fmt.Sprintf("get-article:%s", URLParam(r, "id"))))
981	})
982	hSyncArticle := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
983		w.Write([]byte(fmt.Sprintf("sync-article:%s", URLParam(r, "id"))))
984	})
985
986	r := NewRouter()
987	// var rr1, rr2 *Mux
988	r.Get("/", hIndex)
989	r.Route("/articles", func(r Router) {
990		// rr1 = r.(*Mux)
991		r.Get("/", hArticlesList)
992		r.Get("/search", hSearchArticles)
993		r.Route("/{id}", func(r Router) {
994			// rr2 = r.(*Mux)
995			r.Get("/", hGetArticle)
996			r.Get("/sync", hSyncArticle)
997		})
998	})
999
1000	// log.Println("~~~~~~~~~")
1001	// log.Println("~~~~~~~~~")
1002	// debugPrintTree(0, 0, r.tree, 0)
1003	// log.Println("~~~~~~~~~")
1004	// log.Println("~~~~~~~~~")
1005
1006	// log.Println("~~~~~~~~~")
1007	// log.Println("~~~~~~~~~")
1008	// debugPrintTree(0, 0, rr1.tree, 0)
1009	// log.Println("~~~~~~~~~")
1010	// log.Println("~~~~~~~~~")
1011
1012	// log.Println("~~~~~~~~~")
1013	// log.Println("~~~~~~~~~")
1014	// debugPrintTree(0, 0, rr2.tree, 0)
1015	// log.Println("~~~~~~~~~")
1016	// log.Println("~~~~~~~~~")
1017
1018	ts := httptest.NewServer(r)
1019	defer ts.Close()
1020
1021	var body, expected string
1022
1023	_, body = testRequest(t, ts, "GET", "/", nil)
1024	expected = "index"
1025	if body != expected {
1026		t.Fatalf("expected:%s got:%s", expected, body)
1027	}
1028	_, body = testRequest(t, ts, "GET", "/articles", nil)
1029	expected = "articles-list"
1030	if body != expected {
1031		t.Fatalf("expected:%s got:%s", expected, body)
1032	}
1033	_, body = testRequest(t, ts, "GET", "/articles/search", nil)
1034	expected = "search-articles"
1035	if body != expected {
1036		t.Fatalf("expected:%s got:%s", expected, body)
1037	}
1038	_, body = testRequest(t, ts, "GET", "/articles/123", nil)
1039	expected = "get-article:123"
1040	if body != expected {
1041		t.Fatalf("expected:%s got:%s", expected, body)
1042	}
1043	_, body = testRequest(t, ts, "GET", "/articles/123/sync", nil)
1044	expected = "sync-article:123"
1045	if body != expected {
1046		t.Fatalf("expected:%s got:%s", expected, body)
1047	}
1048}
1049
1050func TestMuxSubroutes(t *testing.T) {
1051	hHubView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1052		w.Write([]byte("hub1"))
1053	})
1054	hHubView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1055		w.Write([]byte("hub2"))
1056	})
1057	hHubView3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1058		w.Write([]byte("hub3"))
1059	})
1060	hAccountView1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1061		w.Write([]byte("account1"))
1062	})
1063	hAccountView2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1064		w.Write([]byte("account2"))
1065	})
1066
1067	r := NewRouter()
1068	r.Get("/hubs/{hubID}/view", hHubView1)
1069	r.Get("/hubs/{hubID}/view/*", hHubView2)
1070
1071	sr := NewRouter()
1072	sr.Get("/", hHubView3)
1073	r.Mount("/hubs/{hubID}/users", sr)
1074	r.Get("/hubs/{hubID}/users/", func(w http.ResponseWriter, r *http.Request) {
1075		w.Write([]byte("hub3 override"))
1076	})
1077
1078	sr3 := NewRouter()
1079	sr3.Get("/", hAccountView1)
1080	sr3.Get("/hi", hAccountView2)
1081
1082	// var sr2 *Mux
1083	r.Route("/accounts/{accountID}", func(r Router) {
1084		_ = r.(*Mux) // sr2
1085		// r.Get("/", hAccountView1)
1086		r.Mount("/", sr3)
1087	})
1088
1089	// This is the same as the r.Route() call mounted on sr2
1090	// sr2 := NewRouter()
1091	// sr2.Mount("/", sr3)
1092	// r.Mount("/accounts/{accountID}", sr2)
1093
1094	ts := httptest.NewServer(r)
1095	defer ts.Close()
1096
1097	var body, expected string
1098
1099	_, body = testRequest(t, ts, "GET", "/hubs/123/view", nil)
1100	expected = "hub1"
1101	if body != expected {
1102		t.Fatalf("expected:%s got:%s", expected, body)
1103	}
1104	_, body = testRequest(t, ts, "GET", "/hubs/123/view/index.html", nil)
1105	expected = "hub2"
1106	if body != expected {
1107		t.Fatalf("expected:%s got:%s", expected, body)
1108	}
1109	_, body = testRequest(t, ts, "GET", "/hubs/123/users", nil)
1110	expected = "hub3"
1111	if body != expected {
1112		t.Fatalf("expected:%s got:%s", expected, body)
1113	}
1114	_, body = testRequest(t, ts, "GET", "/hubs/123/users/", nil)
1115	expected = "hub3 override"
1116	if body != expected {
1117		t.Fatalf("expected:%s got:%s", expected, body)
1118	}
1119	_, body = testRequest(t, ts, "GET", "/accounts/44", nil)
1120	expected = "account1"
1121	if body != expected {
1122		t.Fatalf("request:%s expected:%s got:%s", "GET /accounts/44", expected, body)
1123	}
1124	_, body = testRequest(t, ts, "GET", "/accounts/44/hi", nil)
1125	expected = "account2"
1126	if body != expected {
1127		t.Fatalf("expected:%s got:%s", expected, body)
1128	}
1129
1130	// Test that we're building the routingPatterns properly
1131	router := r
1132	req, _ := http.NewRequest("GET", "/accounts/44/hi", nil)
1133
1134	rctx := NewRouteContext()
1135	req = req.WithContext(context.WithValue(req.Context(), RouteCtxKey, rctx))
1136
1137	w := httptest.NewRecorder()
1138	router.ServeHTTP(w, req)
1139
1140	body = w.Body.String()
1141	expected = "account2"
1142	if body != expected {
1143		t.Fatalf("expected:%s got:%s", expected, body)
1144	}
1145
1146	routePatterns := rctx.RoutePatterns
1147	if len(rctx.RoutePatterns) != 3 {
1148		t.Fatalf("expected 3 routing patterns, got:%d", len(rctx.RoutePatterns))
1149	}
1150	expected = "/accounts/{accountID}/*"
1151	if routePatterns[0] != expected {
1152		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[0])
1153	}
1154	expected = "/*"
1155	if routePatterns[1] != expected {
1156		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[1])
1157	}
1158	expected = "/hi"
1159	if routePatterns[2] != expected {
1160		t.Fatalf("routePattern, expected:%s got:%s", expected, routePatterns[2])
1161	}
1162
1163}
1164
1165func TestSingleHandler(t *testing.T) {
1166	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1167		name := URLParam(r, "name")
1168		w.Write([]byte("hi " + name))
1169	})
1170
1171	r, _ := http.NewRequest("GET", "/", nil)
1172	rctx := NewRouteContext()
1173	r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx))
1174	rctx.URLParams.Add("name", "joe")
1175
1176	w := httptest.NewRecorder()
1177	h.ServeHTTP(w, r)
1178
1179	body := w.Body.String()
1180	expected := "hi joe"
1181	if body != expected {
1182		t.Fatalf("expected:%s got:%s", expected, body)
1183	}
1184}
1185
// TODO: add a test for mounting a custom Router type that embeds *Mux (see ACLMux below).
1187//
1188// type ACLMux struct {
1189// 	*Mux
1190// 	XX string
1191// }
1192//
1193// func NewACLMux() *ACLMux {
1194// 	return &ACLMux{Mux: NewRouter(), XX: "hihi"}
1195// }
1196//
1197// // TODO: this should be supported...
1198// func TestWoot(t *testing.T) {
1199// 	var r Router = NewRouter()
1200//
1201// 	var r2 Router = NewACLMux() //NewRouter()
1202// 	r2.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
1203// 		w.Write([]byte("hi"))
1204// 	})
1205//
1206// 	r.Mount("/", r2)
1207// }
1208
1209func TestServeHTTPExistingContext(t *testing.T) {
1210	r := NewRouter()
1211	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
1212		s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
1213		w.Write([]byte(s))
1214	})
1215	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
1216		s, _ := r.Context().Value(ctxKey{"testCtx"}).(string)
1217		w.WriteHeader(404)
1218		w.Write([]byte(s))
1219	})
1220
1221	testcases := []struct {
1222		Method         string
1223		Path           string
1224		Ctx            context.Context
1225		ExpectedStatus int
1226		ExpectedBody   string
1227	}{
1228		{
1229			Method:         "GET",
1230			Path:           "/hi",
1231			Ctx:            context.WithValue(context.Background(), ctxKey{"testCtx"}, "hi ctx"),
1232			ExpectedStatus: 200,
1233			ExpectedBody:   "hi ctx",
1234		},
1235		{
1236			Method:         "GET",
1237			Path:           "/hello",
1238			Ctx:            context.WithValue(context.Background(), ctxKey{"testCtx"}, "nothing here ctx"),
1239			ExpectedStatus: 404,
1240			ExpectedBody:   "nothing here ctx",
1241		},
1242	}
1243
1244	for _, tc := range testcases {
1245		resp := httptest.NewRecorder()
1246		req, err := http.NewRequest(tc.Method, tc.Path, nil)
1247		if err != nil {
1248			t.Fatalf("%v", err)
1249		}
1250		req = req.WithContext(tc.Ctx)
1251		r.ServeHTTP(resp, req)
1252		b, err := ioutil.ReadAll(resp.Body)
1253		if err != nil {
1254			t.Fatalf("%v", err)
1255		}
1256		if resp.Code != tc.ExpectedStatus {
1257			t.Fatalf("%v != %v", tc.ExpectedStatus, resp.Code)
1258		}
1259		if string(b) != tc.ExpectedBody {
1260			t.Fatalf("%s != %s", tc.ExpectedBody, b)
1261		}
1262	}
1263}
1264
1265func TestNestedGroups(t *testing.T) {
1266	handlerPrintCounter := func(w http.ResponseWriter, r *http.Request) {
1267		counter, _ := r.Context().Value(ctxKey{"counter"}).(int)
1268		w.Write([]byte(fmt.Sprintf("%v", counter)))
1269	}
1270
1271	mwIncreaseCounter := func(next http.Handler) http.Handler {
1272		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1273			ctx := r.Context()
1274			counter, _ := ctx.Value(ctxKey{"counter"}).(int)
1275			counter++
1276			ctx = context.WithValue(ctx, ctxKey{"counter"}, counter)
1277			next.ServeHTTP(w, r.WithContext(ctx))
1278		})
1279	}
1280
1281	// Each route represents value of its counter (number of applied middlewares).
1282	r := NewRouter() // counter == 0
1283	r.Get("/0", handlerPrintCounter)
1284	r.Group(func(r Router) {
1285		r.Use(mwIncreaseCounter) // counter == 1
1286		r.Get("/1", handlerPrintCounter)
1287
1288		// r.Handle(GET, "/2", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
1289		r.With(mwIncreaseCounter).Get("/2", handlerPrintCounter)
1290
1291		r.Group(func(r Router) {
1292			r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 3
1293			r.Get("/3", handlerPrintCounter)
1294		})
1295		r.Route("/", func(r Router) {
1296			r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 3
1297
1298			// r.Handle(GET, "/4", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
1299			r.With(mwIncreaseCounter).Get("/4", handlerPrintCounter)
1300
1301			r.Group(func(r Router) {
1302				r.Use(mwIncreaseCounter, mwIncreaseCounter) // counter == 5
1303				r.Get("/5", handlerPrintCounter)
1304				// r.Handle(GET, "/6", Chain(mwIncreaseCounter).HandlerFunc(handlerPrintCounter))
1305				r.With(mwIncreaseCounter).Get("/6", handlerPrintCounter)
1306
1307			})
1308		})
1309	})
1310
1311	ts := httptest.NewServer(r)
1312	defer ts.Close()
1313
1314	for _, route := range []string{"0", "1", "2", "3", "4", "5", "6"} {
1315		if _, body := testRequest(t, ts, "GET", "/"+route, nil); body != route {
1316			t.Errorf("expected %v, got %v", route, body)
1317		}
1318	}
1319}
1320
1321func TestMiddlewarePanicOnLateUse(t *testing.T) {
1322	handler := func(w http.ResponseWriter, r *http.Request) {
1323		w.Write([]byte("hello\n"))
1324	}
1325
1326	mw := func(next http.Handler) http.Handler {
1327		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1328			next.ServeHTTP(w, r)
1329		})
1330	}
1331
1332	defer func() {
1333		if recover() == nil {
1334			t.Error("expected panic()")
1335		}
1336	}()
1337
1338	r := NewRouter()
1339	r.Get("/", handler)
1340	r.Use(mw) // Too late to apply middleware, we're expecting panic().
1341}
1342
1343func TestMountingExistingPath(t *testing.T) {
1344	handler := func(w http.ResponseWriter, r *http.Request) {}
1345
1346	defer func() {
1347		if recover() == nil {
1348			t.Error("expected panic()")
1349		}
1350	}()
1351
1352	r := NewRouter()
1353	r.Get("/", handler)
1354	r.Mount("/hi", http.HandlerFunc(handler))
1355	r.Mount("/hi", http.HandlerFunc(handler))
1356}
1357
1358func TestMountingSimilarPattern(t *testing.T) {
1359	r := NewRouter()
1360	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
1361		w.Write([]byte("bye"))
1362	})
1363
1364	r2 := NewRouter()
1365	r2.Get("/", func(w http.ResponseWriter, r *http.Request) {
1366		w.Write([]byte("foobar"))
1367	})
1368
1369	r3 := NewRouter()
1370	r3.Get("/", func(w http.ResponseWriter, r *http.Request) {
1371		w.Write([]byte("foo"))
1372	})
1373
1374	r.Mount("/foobar", r2)
1375	r.Mount("/foo", r3)
1376
1377	ts := httptest.NewServer(r)
1378	defer ts.Close()
1379
1380	if _, body := testRequest(t, ts, "GET", "/hi", nil); body != "bye" {
1381		t.Fatalf(body)
1382	}
1383}
1384
1385func TestMuxEmptyParams(t *testing.T) {
1386	r := NewRouter()
1387	r.Get(`/users/{x}/{y}/{z}`, func(w http.ResponseWriter, r *http.Request) {
1388		x := URLParam(r, "x")
1389		y := URLParam(r, "y")
1390		z := URLParam(r, "z")
1391		w.Write([]byte(fmt.Sprintf("%s-%s-%s", x, y, z)))
1392	})
1393
1394	ts := httptest.NewServer(r)
1395	defer ts.Close()
1396
1397	if _, body := testRequest(t, ts, "GET", "/users/a/b/c", nil); body != "a-b-c" {
1398		t.Fatalf(body)
1399	}
1400	if _, body := testRequest(t, ts, "GET", "/users///c", nil); body != "--c" {
1401		t.Fatalf(body)
1402	}
1403}
1404
1405func TestMuxMissingParams(t *testing.T) {
1406	r := NewRouter()
1407	r.Get(`/user/{userId:\d+}`, func(w http.ResponseWriter, r *http.Request) {
1408		userID := URLParam(r, "userId")
1409		w.Write([]byte(fmt.Sprintf("userId = '%s'", userID)))
1410	})
1411	r.NotFound(func(w http.ResponseWriter, r *http.Request) {
1412		w.WriteHeader(404)
1413		w.Write([]byte("nothing here"))
1414	})
1415
1416	ts := httptest.NewServer(r)
1417	defer ts.Close()
1418
1419	if _, body := testRequest(t, ts, "GET", "/user/123", nil); body != "userId = '123'" {
1420		t.Fatalf(body)
1421	}
1422	if _, body := testRequest(t, ts, "GET", "/user/", nil); body != "nothing here" {
1423		t.Fatalf(body)
1424	}
1425}
1426
1427func TestMuxWildcardRoute(t *testing.T) {
1428	handler := func(w http.ResponseWriter, r *http.Request) {}
1429
1430	defer func() {
1431		if recover() == nil {
1432			t.Error("expected panic()")
1433		}
1434	}()
1435
1436	r := NewRouter()
1437	r.Get("/*/wildcard/must/be/at/end", handler)
1438}
1439
1440func TestMuxWildcardRouteCheckTwo(t *testing.T) {
1441	handler := func(w http.ResponseWriter, r *http.Request) {}
1442
1443	defer func() {
1444		if recover() == nil {
1445			t.Error("expected panic()")
1446		}
1447	}()
1448
1449	r := NewRouter()
1450	r.Get("/*/wildcard/{must}/be/at/end", handler)
1451}
1452
1453func TestMuxRegexp(t *testing.T) {
1454	r := NewRouter()
1455	r.Route("/{param:[0-9]*}/test", func(r Router) {
1456		r.Get("/", func(w http.ResponseWriter, r *http.Request) {
1457			w.Write([]byte(fmt.Sprintf("Hi: %s", URLParam(r, "param"))))
1458		})
1459	})
1460
1461	ts := httptest.NewServer(r)
1462	defer ts.Close()
1463
1464	if _, body := testRequest(t, ts, "GET", "//test", nil); body != "Hi: " {
1465		t.Fatalf(body)
1466	}
1467}
1468
1469func TestMuxRegexp2(t *testing.T) {
1470	r := NewRouter()
1471	r.Get("/foo-{suffix:[a-z]{2,3}}.json", func(w http.ResponseWriter, r *http.Request) {
1472		w.Write([]byte(URLParam(r, "suffix")))
1473	})
1474	ts := httptest.NewServer(r)
1475	defer ts.Close()
1476
1477	if _, body := testRequest(t, ts, "GET", "/foo-.json", nil); body != "" {
1478		t.Fatalf(body)
1479	}
1480	if _, body := testRequest(t, ts, "GET", "/foo-abc.json", nil); body != "abc" {
1481		t.Fatalf(body)
1482	}
1483}
1484
1485func TestMuxRegexp3(t *testing.T) {
1486	r := NewRouter()
1487	r.Get("/one/{firstId:[a-z0-9-]+}/{secondId:[a-z]+}/first", func(w http.ResponseWriter, r *http.Request) {
1488		w.Write([]byte("first"))
1489	})
1490	r.Get("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
1491		w.Write([]byte("second"))
1492	})
1493	r.Delete("/one/{firstId:[a-z0-9-_]+}/{secondId:[0-9]+}/second", func(w http.ResponseWriter, r *http.Request) {
1494		w.Write([]byte("third"))
1495	})
1496
1497	r.Route("/one", func(r Router) {
1498		r.Get("/{dns:[a-z-0-9_]+}", func(writer http.ResponseWriter, request *http.Request) {
1499			writer.Write([]byte("_"))
1500		})
1501		r.Get("/{dns:[a-z-0-9_]+}/info", func(writer http.ResponseWriter, request *http.Request) {
1502			writer.Write([]byte("_"))
1503		})
1504		r.Delete("/{id:[0-9]+}", func(writer http.ResponseWriter, request *http.Request) {
1505			writer.Write([]byte("forth"))
1506		})
1507	})
1508
1509	ts := httptest.NewServer(r)
1510	defer ts.Close()
1511
1512	if _, body := testRequest(t, ts, "GET", "/one/hello/peter/first", nil); body != "first" {
1513		t.Fatalf(body)
1514	}
1515	if _, body := testRequest(t, ts, "GET", "/one/hithere/123/second", nil); body != "second" {
1516		t.Fatalf(body)
1517	}
1518	if _, body := testRequest(t, ts, "DELETE", "/one/hithere/123/second", nil); body != "third" {
1519		t.Fatalf(body)
1520	}
1521	if _, body := testRequest(t, ts, "DELETE", "/one/123", nil); body != "forth" {
1522		t.Fatalf(body)
1523	}
1524}
1525
1526func TestMuxSubrouterWildcardParam(t *testing.T) {
1527	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1528		fmt.Fprintf(w, "param:%v *:%v", URLParam(r, "param"), URLParam(r, "*"))
1529	})
1530
1531	r := NewRouter()
1532
1533	r.Get("/bare/{param}", h)
1534	r.Get("/bare/{param}/*", h)
1535
1536	r.Route("/case0", func(r Router) {
1537		r.Get("/{param}", h)
1538		r.Get("/{param}/*", h)
1539	})
1540
1541	ts := httptest.NewServer(r)
1542	defer ts.Close()
1543
1544	if _, body := testRequest(t, ts, "GET", "/bare/hi", nil); body != "param:hi *:" {
1545		t.Fatalf(body)
1546	}
1547	if _, body := testRequest(t, ts, "GET", "/bare/hi/yes", nil); body != "param:hi *:yes" {
1548		t.Fatalf(body)
1549	}
1550	if _, body := testRequest(t, ts, "GET", "/case0/hi", nil); body != "param:hi *:" {
1551		t.Fatalf(body)
1552	}
1553	if _, body := testRequest(t, ts, "GET", "/case0/hi/yes", nil); body != "param:hi *:yes" {
1554		t.Fatalf(body)
1555	}
1556}
1557
1558func TestMuxContextIsThreadSafe(t *testing.T) {
1559	router := NewRouter()
1560	router.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
1561		ctx, cancel := context.WithTimeout(r.Context(), 1*time.Millisecond)
1562		defer cancel()
1563
1564		<-ctx.Done()
1565	})
1566
1567	wg := sync.WaitGroup{}
1568
1569	for i := 0; i < 100; i++ {
1570		wg.Add(1)
1571		go func() {
1572			defer wg.Done()
1573			for j := 0; j < 10000; j++ {
1574				w := httptest.NewRecorder()
1575				r, err := http.NewRequest("GET", "/ok", nil)
1576				if err != nil {
1577					t.Fatal(err)
1578				}
1579
1580				ctx, cancel := context.WithCancel(r.Context())
1581				r = r.WithContext(ctx)
1582
1583				go func() {
1584					cancel()
1585				}()
1586				router.ServeHTTP(w, r)
1587			}
1588		}()
1589	}
1590	wg.Wait()
1591}
1592
1593func TestEscapedURLParams(t *testing.T) {
1594	m := NewRouter()
1595	m.Get("/api/{identifier}/{region}/{size}/{rotation}/*", func(w http.ResponseWriter, r *http.Request) {
1596		w.WriteHeader(200)
1597		rctx := RouteContext(r.Context())
1598		if rctx == nil {
1599			t.Error("no context")
1600			return
1601		}
1602		identifier := URLParam(r, "identifier")
1603		if identifier != "http:%2f%2fexample.com%2fimage.png" {
1604			t.Errorf("identifier path parameter incorrect %s", identifier)
1605			return
1606		}
1607		region := URLParam(r, "region")
1608		if region != "full" {
1609			t.Errorf("region path parameter incorrect %s", region)
1610			return
1611		}
1612		size := URLParam(r, "size")
1613		if size != "max" {
1614			t.Errorf("size path parameter incorrect %s", size)
1615			return
1616		}
1617		rotation := URLParam(r, "rotation")
1618		if rotation != "0" {
1619			t.Errorf("rotation path parameter incorrect %s", rotation)
1620			return
1621		}
1622		w.Write([]byte("success"))
1623	})
1624
1625	ts := httptest.NewServer(m)
1626	defer ts.Close()
1627
1628	if _, body := testRequest(t, ts, "GET", "/api/http:%2f%2fexample.com%2fimage.png/full/max/0/color.png", nil); body != "success" {
1629		t.Fatalf(body)
1630	}
1631}
1632
1633func TestMuxMatch(t *testing.T) {
1634	r := NewRouter()
1635	r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
1636		w.Header().Set("X-Test", "yes")
1637		w.Write([]byte("bye"))
1638	})
1639	r.Route("/articles", func(r Router) {
1640		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
1641			id := URLParam(r, "id")
1642			w.Header().Set("X-Article", id)
1643			w.Write([]byte("article:" + id))
1644		})
1645	})
1646	r.Route("/users", func(r Router) {
1647		r.Head("/{id}", func(w http.ResponseWriter, r *http.Request) {
1648			w.Header().Set("X-User", "-")
1649			w.Write([]byte("user"))
1650		})
1651		r.Get("/{id}", func(w http.ResponseWriter, r *http.Request) {
1652			id := URLParam(r, "id")
1653			w.Header().Set("X-User", id)
1654			w.Write([]byte("user:" + id))
1655		})
1656	})
1657
1658	tctx := NewRouteContext()
1659
1660	tctx.Reset()
1661	if r.Match(tctx, "GET", "/users/1") == false {
1662		t.Fatal("expecting to find match for route:", "GET", "/users/1")
1663	}
1664
1665	tctx.Reset()
1666	if r.Match(tctx, "HEAD", "/articles/10") == true {
1667		t.Fatal("not expecting to find match for route:", "HEAD", "/articles/10")
1668	}
1669}
1670
1671func TestServerBaseContext(t *testing.T) {
1672	r := NewRouter()
1673	r.Get("/", func(w http.ResponseWriter, r *http.Request) {
1674		baseYes := r.Context().Value(ctxKey{"base"}).(string)
1675		if _, ok := r.Context().Value(http.ServerContextKey).(*http.Server); !ok {
1676			panic("missing server context")
1677		}
1678		if _, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr); !ok {
1679			panic("missing local addr context")
1680		}
1681		w.Write([]byte(baseYes))
1682	})
1683
1684	// Setup http Server with a base context
1685	ctx := context.WithValue(context.Background(), ctxKey{"base"}, "yes")
1686	ts := httptest.NewUnstartedServer(r)
1687	ts.Config.BaseContext = func(_ net.Listener) context.Context {
1688		return ctx
1689	}
1690	ts.Start()
1691
1692	defer ts.Close()
1693
1694	if _, body := testRequest(t, ts, "GET", "/", nil); body != "yes" {
1695		t.Fatalf(body)
1696	}
1697}
1698
// testRequest sends a request to the test server using the given method, a
// path relative to ts.URL, and an optional body. It returns the response and
// its fully read body as a string. A request-construction, transport or read
// error fails the test immediately. The returned response's Body is already
// closed.
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	req, err := http.NewRequest(method, ts.URL+path, body)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	// Register the close before reading so the body is released even when
	// ReadAll fails and t.Fatal unwinds this goroutine.
	defer resp.Body.Close()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}

	return resp, string(respBody)
}
1721
// testHandler serves a single request directly through h, with no network
// round-trip, using the given method, path and optional body. It returns the
// recorded response and the recorded body as a string. An invalid method or
// path fails the test immediately instead of handing a nil request to h.
func testHandler(t *testing.T, h http.Handler, method, path string, body io.Reader) (*http.Response, string) {
	t.Helper()

	r, err := http.NewRequest(method, path, body)
	if err != nil {
		t.Fatal(err)
	}
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	return w.Result(), w.Body.String()
}
1728
// testFileSystem implements http.FileSystem by delegating Open to a function
// supplied by the test, so each test decides what Open returns.
type testFileSystem struct {
	open func(name string) (http.File, error)
}

// Open forwards name to the configured open function and returns its result unchanged.
func (tfs *testFileSystem) Open(name string) (http.File, error) {
	opener := tfs.open
	return opener(name)
}
1736
// testFile is a minimal in-memory http.File fixture.
//
// NOTE(review): Read is stateless and does not follow the usual io.Reader
// contract. Every call copies contents from the beginning into p, reports
// len(p) bytes read with a nil error, and never returns io.EOF. Bytes of p
// beyond len(contents) are left untouched. Seek is a no-op that always reports
// position 0. Callers presumably bound their reads by Stat().Size(); confirm
// before relying on this fixture elsewhere.
type testFile struct {
	name     string
	contents []byte
}

// Close is a no-op because there is no underlying resource.
func (f *testFile) Close() error {
	return nil
}

// Read copies the start of contents into p and claims to have filled p
// entirely (see the NOTE on testFile).
func (f *testFile) Read(p []byte) (n int, err error) {
	copy(p, f.contents)
	n = len(p)
	return n, nil
}

// Seek ignores its arguments and always reports offset 0 with no error.
func (f *testFile) Seek(offset int64, whence int) (int64, error) {
	return 0, nil
}

// Readdir ignores count and returns a one-element listing that describes f itself.
func (f *testFile) Readdir(count int) ([]os.FileInfo, error) {
	info, _ := f.Stat() // Stat never returns an error.
	return []os.FileInfo{info}, nil
}

// Stat describes f with its name and a size of len(contents) bytes.
func (f *testFile) Stat() (os.FileInfo, error) {
	info := &testFileInfo{name: f.name, size: int64(len(f.contents))}
	return info, nil
}

// testFileInfo is the os.FileInfo returned by testFile.Stat. It always reports
// a non-directory with mode 0755, the current time as ModTime, and no
// system-specific data.
type testFileInfo struct {
	name string
	size int64
}

func (fi *testFileInfo) Name() string       { return fi.name }
func (fi *testFileInfo) Size() int64        { return fi.size }
func (fi *testFileInfo) Mode() os.FileMode  { return 0755 }
func (fi *testFileInfo) ModTime() time.Time { return time.Now() }
func (fi *testFileInfo) IsDir() bool        { return false }
func (fi *testFileInfo) Sys() interface{}   { return nil }
1775
// ctxKey is the key type for context values set by these tests. Two keys are
// equal exactly when their names match.
type ctxKey struct {
	name string
}

// String implements fmt.Stringer, rendering the key as "context value <name>".
func (key ctxKey) String() string {
	label := "context value "
	return label + key.name
}
1783
1784func BenchmarkMux(b *testing.B) {
1785	h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1786	h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1787	h3 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1788	h4 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1789	h5 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1790	h6 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
1791
1792	mx := NewRouter()
1793	mx.Get("/", h1)
1794	mx.Get("/hi", h2)
1795	mx.Get("/sup/{id}/and/{this}", h3)
1796	mx.Get("/sup/{id}/{bar:foo}/{this}", h3)
1797
1798	mx.Route("/sharing/{x}/{hash}", func(mx Router) {
1799		mx.Get("/", h4)          // subrouter-1
1800		mx.Get("/{network}", h5) // subrouter-1
1801		mx.Get("/twitter", h5)
1802		mx.Route("/direct", func(mx Router) {
1803			mx.Get("/", h6) // subrouter-2
1804			mx.Get("/download", h6)
1805		})
1806	})
1807
1808	routes := []string{
1809		"/",
1810		"/hi",
1811		"/sup/123/and/this",
1812		"/sup/123/foo/this",
1813		"/sharing/z/aBc",                 // subrouter-1
1814		"/sharing/z/aBc/twitter",         // subrouter-1
1815		"/sharing/z/aBc/direct",          // subrouter-2
1816		"/sharing/z/aBc/direct/download", // subrouter-2
1817	}
1818
1819	for _, path := range routes {
1820		b.Run("route:"+path, func(b *testing.B) {
1821			w := httptest.NewRecorder()
1822			r, _ := http.NewRequest("GET", path, nil)
1823
1824			b.ReportAllocs()
1825			b.ResetTimer()
1826
1827			for i := 0; i < b.N; i++ {
1828				mx.ServeHTTP(w, r)
1829			}
1830		})
1831	}
1832}
1833