1package chi
2
3import (
4	"context"
5	"fmt"
6	"net/http"
7	"strings"
8	"sync"
9)
10
11var _ Router = &Mux{}
12
// Mux is a simple HTTP route multiplexer that parses a request path,
// records any URL params, and executes an end handler. It implements
// the http.Handler interface and is friendly with the standard library.
//
// Mux is designed to be fast, minimal and offer a powerful API for building
// modular and composable HTTP services with a large set of handlers. It's
// particularly useful for writing large REST API services that break a handler
// into many smaller parts composed of middlewares and end handlers.
//
// A Mux created by NewMux owns its routing tree and context pool. A Mux
// returned by With/Group is "inline": it shares the parent's tree and pool
// and applies its middlewares to each endpoint it registers, rather than as
// one chain in front of the router.
type Mux struct {
	// The computed mux handler made of the chained middleware stack and
	// the tree router. It stays nil until the first route is registered
	// (or With is called). Once it is set, Use panics.
	handler http.Handler

	// The radix trie router; shared with any inline muxes derived via With.
	tree *node

	// Custom method not allowed handler; nil means the package default,
	// which replies 405.
	methodNotAllowedHandler http.HandlerFunc

	// The mux this inline mux was derived from (set by With); nil for a mux
	// built by NewMux. NotFound/MethodNotAllowed on an inline mux install
	// their handler on this parent.
	parent *Mux

	// Routing context pool; shared with inline muxes derived via With.
	pool *sync.Pool

	// Custom route not found handler; nil means http.NotFound.
	notFoundHandler http.HandlerFunc

	// The middleware stack. A root mux chains it once in front of the
	// router; an inline mux wraps each endpoint handler with it.
	middlewares []func(http.Handler) http.Handler

	// Controls the behaviour of middleware chain generation when a mux
	// is registered as an inline group inside another mux (true for
	// muxes returned by With/Group).
	inline bool
}
47
48// NewMux returns a newly initialized Mux object that implements the Router
49// interface.
50func NewMux() *Mux {
51	mux := &Mux{tree: &node{}, pool: &sync.Pool{}}
52	mux.pool.New = func() interface{} {
53		return NewRouteContext()
54	}
55	return mux
56}
57
58// ServeHTTP is the single method of the http.Handler interface that makes
59// Mux interoperable with the standard library. It uses a sync.Pool to get and
60// reuse routing contexts for each request.
61func (mx *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
62	// Ensure the mux has some routes defined on the mux
63	if mx.handler == nil {
64		mx.NotFoundHandler().ServeHTTP(w, r)
65		return
66	}
67
68	// Check if a routing context already exists from a parent router.
69	rctx, _ := r.Context().Value(RouteCtxKey).(*Context)
70	if rctx != nil {
71		mx.handler.ServeHTTP(w, r)
72		return
73	}
74
75	// Fetch a RouteContext object from the sync pool, and call the computed
76	// mx.handler that is comprised of mx.middlewares + mx.routeHTTP.
77	// Once the request is finished, reset the routing context and put it back
78	// into the pool for reuse from another request.
79	rctx = mx.pool.Get().(*Context)
80	rctx.Reset()
81	rctx.Routes = mx
82	rctx.parentCtx = r.Context()
83
84	// NOTE: r.WithContext() causes 2 allocations and context.WithValue() causes 1 allocation
85	r = r.WithContext(context.WithValue(r.Context(), RouteCtxKey, rctx))
86
87	// Serve the request and once its done, put the request context back in the sync pool
88	mx.handler.ServeHTTP(w, r)
89	mx.pool.Put(rctx)
90}
91
92// Use appends a middleware handler to the Mux middleware stack.
93//
94// The middleware stack for any Mux will execute before searching for a matching
95// route to a specific handler, which provides opportunity to respond early,
96// change the course of the request execution, or set request-scoped values for
97// the next http.Handler.
98func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) {
99	if mx.handler != nil {
100		panic("chi: all middlewares must be defined before routes on a mux")
101	}
102	mx.middlewares = append(mx.middlewares, middlewares...)
103}
104
105// Handle adds the route `pattern` that matches any http method to
106// execute the `handler` http.Handler.
107func (mx *Mux) Handle(pattern string, handler http.Handler) {
108	mx.handle(mALL, pattern, handler)
109}
110
111// HandleFunc adds the route `pattern` that matches any http method to
112// execute the `handlerFn` http.HandlerFunc.
113func (mx *Mux) HandleFunc(pattern string, handlerFn http.HandlerFunc) {
114	mx.handle(mALL, pattern, handlerFn)
115}
116
117// Method adds the route `pattern` that matches `method` http method to
118// execute the `handler` http.Handler.
119func (mx *Mux) Method(method, pattern string, handler http.Handler) {
120	m, ok := methodMap[strings.ToUpper(method)]
121	if !ok {
122		panic(fmt.Sprintf("chi: '%s' http method is not supported.", method))
123	}
124	mx.handle(m, pattern, handler)
125}
126
127// MethodFunc adds the route `pattern` that matches `method` http method to
128// execute the `handlerFn` http.HandlerFunc.
129func (mx *Mux) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) {
130	mx.Method(method, pattern, handlerFn)
131}
132
133// Connect adds the route `pattern` that matches a CONNECT http method to
134// execute the `handlerFn` http.HandlerFunc.
135func (mx *Mux) Connect(pattern string, handlerFn http.HandlerFunc) {
136	mx.handle(mCONNECT, pattern, handlerFn)
137}
138
139// Delete adds the route `pattern` that matches a DELETE http method to
140// execute the `handlerFn` http.HandlerFunc.
141func (mx *Mux) Delete(pattern string, handlerFn http.HandlerFunc) {
142	mx.handle(mDELETE, pattern, handlerFn)
143}
144
145// Get adds the route `pattern` that matches a GET http method to
146// execute the `handlerFn` http.HandlerFunc.
147func (mx *Mux) Get(pattern string, handlerFn http.HandlerFunc) {
148	mx.handle(mGET, pattern, handlerFn)
149}
150
151// Head adds the route `pattern` that matches a HEAD http method to
152// execute the `handlerFn` http.HandlerFunc.
153func (mx *Mux) Head(pattern string, handlerFn http.HandlerFunc) {
154	mx.handle(mHEAD, pattern, handlerFn)
155}
156
157// Options adds the route `pattern` that matches a OPTIONS http method to
158// execute the `handlerFn` http.HandlerFunc.
159func (mx *Mux) Options(pattern string, handlerFn http.HandlerFunc) {
160	mx.handle(mOPTIONS, pattern, handlerFn)
161}
162
163// Patch adds the route `pattern` that matches a PATCH http method to
164// execute the `handlerFn` http.HandlerFunc.
165func (mx *Mux) Patch(pattern string, handlerFn http.HandlerFunc) {
166	mx.handle(mPATCH, pattern, handlerFn)
167}
168
169// Post adds the route `pattern` that matches a POST http method to
170// execute the `handlerFn` http.HandlerFunc.
171func (mx *Mux) Post(pattern string, handlerFn http.HandlerFunc) {
172	mx.handle(mPOST, pattern, handlerFn)
173}
174
175// Put adds the route `pattern` that matches a PUT http method to
176// execute the `handlerFn` http.HandlerFunc.
177func (mx *Mux) Put(pattern string, handlerFn http.HandlerFunc) {
178	mx.handle(mPUT, pattern, handlerFn)
179}
180
181// Trace adds the route `pattern` that matches a TRACE http method to
182// execute the `handlerFn` http.HandlerFunc.
183func (mx *Mux) Trace(pattern string, handlerFn http.HandlerFunc) {
184	mx.handle(mTRACE, pattern, handlerFn)
185}
186
187// NotFound sets a custom http.HandlerFunc for routing paths that could
188// not be found. The default 404 handler is `http.NotFound`.
189func (mx *Mux) NotFound(handlerFn http.HandlerFunc) {
190	// Build NotFound handler chain
191	m := mx
192	hFn := handlerFn
193	if mx.inline && mx.parent != nil {
194		m = mx.parent
195		hFn = Chain(mx.middlewares...).HandlerFunc(hFn).ServeHTTP
196	}
197
198	// Update the notFoundHandler from this point forward
199	m.notFoundHandler = hFn
200	m.updateSubRoutes(func(subMux *Mux) {
201		if subMux.notFoundHandler == nil {
202			subMux.NotFound(hFn)
203		}
204	})
205}
206
207// MethodNotAllowed sets a custom http.HandlerFunc for routing paths where the
208// method is unresolved. The default handler returns a 405 with an empty body.
209func (mx *Mux) MethodNotAllowed(handlerFn http.HandlerFunc) {
210	// Build MethodNotAllowed handler chain
211	m := mx
212	hFn := handlerFn
213	if mx.inline && mx.parent != nil {
214		m = mx.parent
215		hFn = Chain(mx.middlewares...).HandlerFunc(hFn).ServeHTTP
216	}
217
218	// Update the methodNotAllowedHandler from this point forward
219	m.methodNotAllowedHandler = hFn
220	m.updateSubRoutes(func(subMux *Mux) {
221		if subMux.methodNotAllowedHandler == nil {
222			subMux.MethodNotAllowed(hFn)
223		}
224	})
225}
226
227// With adds inline middlewares for an endpoint handler.
228func (mx *Mux) With(middlewares ...func(http.Handler) http.Handler) Router {
229	// Similarly as in handle(), we must build the mux handler once additional
230	// middleware registration isn't allowed for this stack, like now.
231	if !mx.inline && mx.handler == nil {
232		mx.updateRouteHandler()
233	}
234
235	// Copy middlewares from parent inline muxs
236	var mws Middlewares
237	if mx.inline {
238		mws = make(Middlewares, len(mx.middlewares))
239		copy(mws, mx.middlewares)
240	}
241	mws = append(mws, middlewares...)
242
243	im := &Mux{
244		pool: mx.pool, inline: true, parent: mx, tree: mx.tree, middlewares: mws,
245		notFoundHandler: mx.notFoundHandler, methodNotAllowedHandler: mx.methodNotAllowedHandler,
246	}
247
248	return im
249}
250
251// Group creates a new inline-Mux with a fresh middleware stack. It's useful
252// for a group of handlers along the same routing path that use an additional
253// set of middlewares. See _examples/.
254func (mx *Mux) Group(fn func(r Router)) Router {
255	im := mx.With().(*Mux)
256	if fn != nil {
257		fn(im)
258	}
259	return im
260}
261
262// Route creates a new Mux with a fresh middleware stack and mounts it
263// along the `pattern` as a subrouter. Effectively, this is a short-hand
264// call to Mount. See _examples/.
265func (mx *Mux) Route(pattern string, fn func(r Router)) Router {
266	if fn == nil {
267		panic(fmt.Sprintf("chi: attempting to Route() a nil subrouter on '%s'", pattern))
268	}
269	subRouter := NewRouter()
270	fn(subRouter)
271	mx.Mount(pattern, subRouter)
272	return subRouter
273}
274
275// Mount attaches another http.Handler or chi Router as a subrouter along a routing
276// path. It's very useful to split up a large API as many independent routers and
277// compose them as a single service using Mount. See _examples/.
278//
279// Note that Mount() simply sets a wildcard along the `pattern` that will continue
280// routing at the `handler`, which in most cases is another chi.Router. As a result,
281// if you define two Mount() routes on the exact same pattern the mount will panic.
282func (mx *Mux) Mount(pattern string, handler http.Handler) {
283	if handler == nil {
284		panic(fmt.Sprintf("chi: attempting to Mount() a nil handler on '%s'", pattern))
285	}
286
287	// Provide runtime safety for ensuring a pattern isn't mounted on an existing
288	// routing pattern.
289	if mx.tree.findPattern(pattern+"*") || mx.tree.findPattern(pattern+"/*") {
290		panic(fmt.Sprintf("chi: attempting to Mount() a handler on an existing path, '%s'", pattern))
291	}
292
293	// Assign sub-Router's with the parent not found & method not allowed handler if not specified.
294	subr, ok := handler.(*Mux)
295	if ok && subr.notFoundHandler == nil && mx.notFoundHandler != nil {
296		subr.NotFound(mx.notFoundHandler)
297	}
298	if ok && subr.methodNotAllowedHandler == nil && mx.methodNotAllowedHandler != nil {
299		subr.MethodNotAllowed(mx.methodNotAllowedHandler)
300	}
301
302	mountHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
303		rctx := RouteContext(r.Context())
304
305		// shift the url path past the previous subrouter
306		rctx.RoutePath = mx.nextRoutePath(rctx)
307
308		// reset the wildcard URLParam which connects the subrouter
309		n := len(rctx.URLParams.Keys) - 1
310		if n >= 0 && rctx.URLParams.Keys[n] == "*" && len(rctx.URLParams.Values) > n {
311			rctx.URLParams.Values[n] = ""
312		}
313
314		handler.ServeHTTP(w, r)
315	})
316
317	if pattern == "" || pattern[len(pattern)-1] != '/' {
318		mx.handle(mALL|mSTUB, pattern, mountHandler)
319		mx.handle(mALL|mSTUB, pattern+"/", mountHandler)
320		pattern += "/"
321	}
322
323	method := mALL
324	subroutes, _ := handler.(Routes)
325	if subroutes != nil {
326		method |= mSTUB
327	}
328	n := mx.handle(method, pattern+"*", mountHandler)
329
330	if subroutes != nil {
331		n.subroutes = subroutes
332	}
333}
334
335// Routes returns a slice of routing information from the tree,
336// useful for traversing available routes of a router.
337func (mx *Mux) Routes() []Route {
338	return mx.tree.routes()
339}
340
341// Middlewares returns a slice of middleware handler functions.
342func (mx *Mux) Middlewares() Middlewares {
343	return mx.middlewares
344}
345
346// Match searches the routing tree for a handler that matches the method/path.
347// It's similar to routing a http request, but without executing the handler
348// thereafter.
349//
350// Note: the *Context state is updated during execution, so manage
351// the state carefully or make a NewRouteContext().
352func (mx *Mux) Match(rctx *Context, method, path string) bool {
353	m, ok := methodMap[method]
354	if !ok {
355		return false
356	}
357
358	node, _, h := mx.tree.FindRoute(rctx, m, path)
359
360	if node != nil && node.subroutes != nil {
361		rctx.RoutePath = mx.nextRoutePath(rctx)
362		return node.subroutes.Match(rctx, method, rctx.RoutePath)
363	}
364
365	return h != nil
366}
367
368// NotFoundHandler returns the default Mux 404 responder whenever a route
369// cannot be found.
370func (mx *Mux) NotFoundHandler() http.HandlerFunc {
371	if mx.notFoundHandler != nil {
372		return mx.notFoundHandler
373	}
374	return http.NotFound
375}
376
377// MethodNotAllowedHandler returns the default Mux 405 responder whenever
378// a method cannot be resolved for a route.
379func (mx *Mux) MethodNotAllowedHandler() http.HandlerFunc {
380	if mx.methodNotAllowedHandler != nil {
381		return mx.methodNotAllowedHandler
382	}
383	return methodNotAllowedHandler
384}
385
386// handle registers a http.Handler in the routing tree for a particular http method
387// and routing pattern.
388func (mx *Mux) handle(method methodTyp, pattern string, handler http.Handler) *node {
389	if len(pattern) == 0 || pattern[0] != '/' {
390		panic(fmt.Sprintf("chi: routing pattern must begin with '/' in '%s'", pattern))
391	}
392
393	// Build the computed routing handler for this routing pattern.
394	if !mx.inline && mx.handler == nil {
395		mx.updateRouteHandler()
396	}
397
398	// Build endpoint handler with inline middlewares for the route
399	var h http.Handler
400	if mx.inline {
401		mx.handler = http.HandlerFunc(mx.routeHTTP)
402		h = Chain(mx.middlewares...).Handler(handler)
403	} else {
404		h = handler
405	}
406
407	// Add the endpoint to the tree and return the node
408	return mx.tree.InsertRoute(method, pattern, h)
409}
410
411// routeHTTP routes a http.Request through the Mux routing tree to serve
412// the matching handler for a particular http method.
413func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) {
414	// Grab the route context object
415	rctx := r.Context().Value(RouteCtxKey).(*Context)
416
417	// The request routing path
418	routePath := rctx.RoutePath
419	if routePath == "" {
420		if r.URL.RawPath != "" {
421			routePath = r.URL.RawPath
422		} else {
423			routePath = r.URL.Path
424		}
425		if routePath == "" {
426			routePath = "/"
427		}
428	}
429
430	// Check if method is supported by chi
431	if rctx.RouteMethod == "" {
432		rctx.RouteMethod = r.Method
433	}
434	method, ok := methodMap[rctx.RouteMethod]
435	if !ok {
436		mx.MethodNotAllowedHandler().ServeHTTP(w, r)
437		return
438	}
439
440	// Find the route
441	if _, _, h := mx.tree.FindRoute(rctx, method, routePath); h != nil {
442		h.ServeHTTP(w, r)
443		return
444	}
445	if rctx.methodNotAllowed {
446		mx.MethodNotAllowedHandler().ServeHTTP(w, r)
447	} else {
448		mx.NotFoundHandler().ServeHTTP(w, r)
449	}
450}
451
452func (mx *Mux) nextRoutePath(rctx *Context) string {
453	routePath := "/"
454	nx := len(rctx.routeParams.Keys) - 1 // index of last param in list
455	if nx >= 0 && rctx.routeParams.Keys[nx] == "*" && len(rctx.routeParams.Values) > nx {
456		routePath = "/" + rctx.routeParams.Values[nx]
457	}
458	return routePath
459}
460
461// Recursively update data on child routers.
462func (mx *Mux) updateSubRoutes(fn func(subMux *Mux)) {
463	for _, r := range mx.tree.routes() {
464		subMux, ok := r.SubRoutes.(*Mux)
465		if !ok {
466			continue
467		}
468		fn(subMux)
469	}
470}
471
472// updateRouteHandler builds the single mux handler that is a chain of the middleware
473// stack, as defined by calls to Use(), and the tree router (Mux) itself. After this
474// point, no other middlewares can be registered on this Mux's stack. But you can still
475// compose additional middlewares via Group()'s or using a chained middleware handler.
476func (mx *Mux) updateRouteHandler() {
477	mx.handler = chain(mx.middlewares, http.HandlerFunc(mx.routeHTTP))
478}
479
// methodNotAllowedHandler is the default responder used when a request's
// method cannot be resolved for a route (including methods chi does not
// recognise). It replies with 405 Method Not Allowed and an empty body.
//
// NOTE(review): RFC 9110 requires an Allow header on 405 responses. This
// handler has no access to the route's allowed methods, so none is set.
func methodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusMethodNotAllowed)
	// Zero-length write kept for compatibility; with no body to deliver, the
	// error is deliberately discarded.
	_, _ = w.Write(nil)
}
486