1package chi
2
3import (
4	"context"
5	"fmt"
6	"net/http"
7	"strings"
8	"sync"
9)
10
11var _ Router = &Mux{}
12
// Mux is a simple HTTP route multiplexer that parses a request path,
// records any URL params, and executes an end handler. It implements
// the http.Handler interface and is friendly with the standard library.
//
// Mux is designed to be fast, minimal and offer a powerful API for building
// modular and composable HTTP services with a large set of handlers. It's
// particularly useful for writing large REST API services that break a handler
// into many smaller parts composed of middlewares and end handlers.
//
// Construct a Mux with NewMux, which initializes tree and pool; the zero
// value leaves both nil.
type Mux struct {
	// The radix trie router. Inline muxes created by With/Group share this
	// same tree with the mux they were created from.
	tree *node

	// The middleware stack. On a root mux it is chained in front of routeHTTP
	// by buildRouteHandler; on an inline mux it is wrapped around each endpoint
	// handler at registration time (see handle).
	middlewares []func(http.Handler) http.Handler

	// Controls the behaviour of middleware chain generation when a mux
	// is registered as an inline group inside another mux. inline is true for
	// muxes returned by With/Group, and parent is the mux With was called on.
	inline bool
	parent *Mux

	// The computed mux handler made of the chained middleware stack and
	// the tree router. It stays nil until the first route registration (or a
	// With call) on a root mux; once set, Use panics.
	handler http.Handler

	// Routing context pool: *Context values reused across top-level requests
	// in ServeHTTP. Shared with inline muxes.
	pool *sync.Pool

	// Custom route not found handler; nil means http.NotFound is used
	// (see NotFoundHandler).
	notFoundHandler http.HandlerFunc

	// Custom method not allowed handler; nil means the package's default 405
	// responder is used (see MethodNotAllowedHandler).
	methodNotAllowedHandler http.HandlerFunc
}
46
// NewMux returns a newly initialized Mux object that implements the Router
// interface. The mux starts with an empty routing tree and a pool that
// produces fresh routing contexts via NewRouteContext when it runs dry.
func NewMux() *Mux {
	ctxPool := &sync.Pool{
		New: func() interface{} {
			return NewRouteContext()
		},
	}
	return &Mux{tree: &node{}, pool: ctxPool}
}
56
// ServeHTTP is the single method of the http.Handler interface that makes
// Mux interoperable with the standard library.
//
// A request that already carries a routing context under RouteCtxKey (as when
// this mux is reached through a parent router's Mount) is served directly with
// that context. Otherwise a *Context is borrowed from mx.pool, reset, attached
// to the request, the computed handler (middlewares + routeHTTP) is run, and
// the context is returned to the pool afterwards. Because the context is
// recycled once ServeHTTP returns, handlers should not retain it beyond the
// request.
//
// Until the handler chain has been built (no route registered and With not
// yet called), every request is answered by the not-found handler.
func (mx *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if mx.handler == nil {
		mx.NotFoundHandler().ServeHTTP(w, r)
		return
	}

	// Nested router: reuse the routing context installed upstream.
	if existing, _ := r.Context().Value(RouteCtxKey).(*Context); existing != nil {
		mx.handler.ServeHTTP(w, r)
		return
	}

	// Top-level request: take a pooled routing context and scope it to r.
	rc := mx.pool.Get().(*Context)
	rc.Reset()
	rc.Routes = mx
	scoped := context.WithValue(r.Context(), RouteCtxKey, rc)
	mx.handler.ServeHTTP(w, r.WithContext(scoped))
	mx.pool.Put(rc)
}
85
// Use appends middlewares to the Mux middleware stack.
//
// The stack runs before the routing tree is searched for a matching route, so
// a middleware can respond early, redirect the flow of the request, or attach
// request-scoped values for the next http.Handler.
//
// Use panics once the mux's handler has been computed (i.e. after a route was
// registered or With was called), because the chain can no longer change.
func (mx *Mux) Use(middlewares ...func(http.Handler) http.Handler) {
	if mx.handler == nil {
		mx.middlewares = append(mx.middlewares, middlewares...)
		return
	}
	panic("chi: all middlewares must be defined before routes on a mux")
}
98
// Handle registers handler on the route `pattern` for every HTTP method.
func (mx *Mux) Handle(pattern string, handler http.Handler) {
	mx.handle(mALL, pattern, handler)
}

// HandleFunc is the http.HandlerFunc counterpart of Handle: handlerFn is
// registered on the route `pattern` for every HTTP method.
func (mx *Mux) HandleFunc(pattern string, handlerFn http.HandlerFunc) {
	mx.Handle(pattern, handlerFn)
}
110
111// Method adds the route `pattern` that matches `method` http method to
112// execute the `handler` http.Handler.
113func (mx *Mux) Method(method, pattern string, handler http.Handler) {
114	m, ok := methodMap[strings.ToUpper(method)]
115	if !ok {
116		panic(fmt.Sprintf("chi: '%s' http method is not supported.", method))
117	}
118	mx.handle(m, pattern, handler)
119}
120
121// MethodFunc adds the route `pattern` that matches `method` http method to
122// execute the `handlerFn` http.HandlerFunc.
123func (mx *Mux) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) {
124	mx.Method(method, pattern, handlerFn)
125}
126
// Connect registers handlerFn for CONNECT requests on the route `pattern`.
func (mx *Mux) Connect(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mCONNECT, pattern, handlerFn)
}

// Delete registers handlerFn for DELETE requests on the route `pattern`.
func (mx *Mux) Delete(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mDELETE, pattern, handlerFn)
}

// Get registers handlerFn for GET requests on the route `pattern`.
func (mx *Mux) Get(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mGET, pattern, handlerFn)
}

// Head registers handlerFn for HEAD requests on the route `pattern`.
func (mx *Mux) Head(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mHEAD, pattern, handlerFn)
}

// Options registers handlerFn for OPTIONS requests on the route `pattern`.
func (mx *Mux) Options(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mOPTIONS, pattern, handlerFn)
}

// Patch registers handlerFn for PATCH requests on the route `pattern`.
func (mx *Mux) Patch(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mPATCH, pattern, handlerFn)
}

// Post registers handlerFn for POST requests on the route `pattern`.
func (mx *Mux) Post(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mPOST, pattern, handlerFn)
}

// Put registers handlerFn for PUT requests on the route `pattern`.
func (mx *Mux) Put(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mPUT, pattern, handlerFn)
}

// Trace registers handlerFn for TRACE requests on the route `pattern`.
func (mx *Mux) Trace(pattern string, handlerFn http.HandlerFunc) {
	mx.handle(mTRACE, pattern, handlerFn)
}
180
// NotFound sets a custom http.HandlerFunc for routing paths that could not be
// found. Without one, http.NotFound is used.
//
// Called on an inline mux (from With/Group), the handler is wrapped in the
// inline middlewares and installed on the parent mux, since the two share a
// routing tree. The handler is then pushed down to every mounted *Mux
// sub-router that does not yet have its own not-found handler.
func (mx *Mux) NotFound(handlerFn http.HandlerFunc) {
	target, fn := mx, handlerFn
	if mx.inline && mx.parent != nil {
		target = mx.parent
		fn = Chain(mx.middlewares...).HandlerFunc(fn).ServeHTTP
	}

	target.notFoundHandler = fn
	target.updateSubRoutes(func(sub *Mux) {
		if sub.notFoundHandler != nil {
			return
		}
		sub.NotFound(fn)
	})
}
200
// MethodNotAllowed sets a custom http.HandlerFunc for routing paths where the
// method is unresolved. Without one, a bare 405 with an empty body is sent.
//
// Called on an inline mux (from With/Group), the handler is wrapped in the
// inline middlewares and installed on the parent mux, since the two share a
// routing tree. The handler is then pushed down to every mounted *Mux
// sub-router that does not yet have its own method-not-allowed handler.
func (mx *Mux) MethodNotAllowed(handlerFn http.HandlerFunc) {
	target, fn := mx, handlerFn
	if mx.inline && mx.parent != nil {
		target = mx.parent
		fn = Chain(mx.middlewares...).HandlerFunc(fn).ServeHTTP
	}

	target.methodNotAllowedHandler = fn
	target.updateSubRoutes(func(sub *Mux) {
		if sub.methodNotAllowedHandler != nil {
			return
		}
		sub.MethodNotAllowed(fn)
	})
}
220
// With adds inline middlewares for an endpoint handler. It returns an inline
// Router that shares this mux's routing tree and context pool. Endpoints
// registered through it get the inline middlewares (including any inherited
// from an inline parent) wrapped around their handler.
//
// On a root mux, With builds the route handler if it was not built yet, so
// subsequent Use calls on that mux panic.
//
// The inline mux also inherits this mux's current NotFound and
// MethodNotAllowed handlers. Without that, Mount/Route called on the inline
// mux (e.g. inside a Group) would see nil fallback handlers and never
// propagate the parent's custom handlers to the mounted sub-router.
func (mx *Mux) With(middlewares ...func(http.Handler) http.Handler) Router {
	// Similarly as in handle(), we must build the mux handler once further
	// middleware registration isn't allowed for this stack, like now.
	if !mx.inline && mx.handler == nil {
		mx.buildRouteHandler()
	}

	// Copy middlewares from parent inline muxes so appending below never
	// aliases the parent's backing array.
	var mws Middlewares
	if mx.inline {
		mws = make(Middlewares, len(mx.middlewares))
		copy(mws, mx.middlewares)
	}
	mws = append(mws, middlewares...)

	im := &Mux{
		pool:        mx.pool,
		inline:      true,
		parent:      mx,
		tree:        mx.tree,
		middlewares: mws,

		// Inherit the fallback handlers so Mount() on the inline mux hands
		// them down to sub-routers.
		notFoundHandler:         mx.notFoundHandler,
		methodNotAllowedHandler: mx.methodNotAllowedHandler,
	}

	return im
}
241
242// Group creates a new inline-Mux with a fresh middleware stack. It's useful
243// for a group of handlers along the same routing path that use an additional
244// set of middlewares. See _examples/.
245func (mx *Mux) Group(fn func(r Router)) Router {
246	im := mx.With().(*Mux)
247	if fn != nil {
248		fn(im)
249	}
250	return im
251}
252
253// Route creates a new Mux with a fresh middleware stack and mounts it
254// along the `pattern` as a subrouter. Effectively, this is a short-hand
255// call to Mount. See _examples/.
256func (mx *Mux) Route(pattern string, fn func(r Router)) Router {
257	subRouter := NewRouter()
258	if fn != nil {
259		fn(subRouter)
260	}
261	mx.Mount(pattern, subRouter)
262	return subRouter
263}
264
// Mount attaches another http.Handler or chi Router as a subrouter along a
// routing path, so a large API can be split into independent routers and
// composed into one service.
//
// Mount registers a wildcard route at `pattern` (plus stub routes for the
// slash and no-slash forms when `pattern` lacks a trailing '/') that continues
// routing at handler. Before dispatching, it rewrites the routing context's
// RoutePath to the wildcard remainder. Mounting twice on the same pattern
// panics.
//
// A *Mux sub-router that has no NotFound/MethodNotAllowed handler of its own
// inherits this mux's custom ones, if set.
func (mx *Mux) Mount(pattern string, handler http.Handler) {
	// Guard against mounting on a path that already has a mount wildcard.
	if mx.tree.findPattern(pattern+"*") || mx.tree.findPattern(pattern+"/*") {
		panic(fmt.Sprintf("chi: attempting to Mount() a handler on an existing path, '%s'", pattern))
	}

	if sub, isMux := handler.(*Mux); isMux {
		if sub.notFoundHandler == nil && mx.notFoundHandler != nil {
			sub.NotFound(mx.notFoundHandler)
		}
		if sub.methodNotAllowedHandler == nil && mx.methodNotAllowedHandler != nil {
			sub.MethodNotAllowed(mx.methodNotAllowedHandler)
		}
	}

	// Scope the remaining path for the sub-router before handing off.
	mountHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rctx := RouteContext(r.Context())
		rctx.RoutePath = mx.nextRoutePath(rctx)
		handler.ServeHTTP(w, r)
	})

	if !strings.HasSuffix(pattern, "/") {
		mx.handle(mALL|mSTUB, pattern, mountHandler)
		mx.handle(mALL|mSTUB, pattern+"/", mountHandler)
		pattern += "/"
	}

	// Handlers exposing their own route table are stubbed and linked so that
	// Routes()/Match() can descend into them.
	wildcard := pattern + "*"
	if subroutes, ok := handler.(Routes); ok && subroutes != nil {
		mx.handle(mALL|mSTUB, wildcard, mountHandler).subroutes = subroutes
	} else {
		mx.handle(mALL, wildcard, mountHandler)
	}
}
312
313// Routes returns a slice of routing information from the tree,
314// useful for traversing available routes of a router.
315func (mx *Mux) Routes() []Route {
316	return mx.tree.routes()
317}
318
319// Middlewares returns a slice of middleware handler functions.
320func (mx *Mux) Middlewares() Middlewares {
321	return mx.middlewares
322}
323
324// Match searches the routing tree for a handler that matches the method/path.
325// It's similar to routing a http request, but without executing the handler
326// thereafter.
327//
328// Note: the *Context state is updated during execution, so manage
329// the state carefully or make a NewRouteContext().
330func (mx *Mux) Match(rctx *Context, method, path string) bool {
331	m, ok := methodMap[method]
332	if !ok {
333		return false
334	}
335
336	node, _, h := mx.tree.FindRoute(rctx, m, path)
337
338	if node != nil && node.subroutes != nil {
339		rctx.RoutePath = mx.nextRoutePath(rctx)
340		return node.subroutes.Match(rctx, method, rctx.RoutePath)
341	}
342
343	return h != nil
344}
345
// NotFoundHandler returns the handler used when no route matches: the custom
// one set via NotFound if present, otherwise http.NotFound.
func (mx *Mux) NotFoundHandler() http.HandlerFunc {
	if custom := mx.notFoundHandler; custom != nil {
		return custom
	}
	return http.NotFound
}
354
// MethodNotAllowedHandler returns the handler used when a route exists but not
// for the request's method: the custom one set via MethodNotAllowed if
// present, otherwise the package's bare 405 responder.
func (mx *Mux) MethodNotAllowedHandler() http.HandlerFunc {
	if custom := mx.methodNotAllowedHandler; custom != nil {
		return custom
	}
	return methodNotAllowedHandler
}
363
// buildRouteHandler computes mx.handler by chaining the middleware stack
// registered with Use in front of the tree router (routeHTTP). Once this has
// run, Use on this mux panics. Extra middlewares can still be layered per
// endpoint through With/Group.
func (mx *Mux) buildRouteHandler() {
	router := http.HandlerFunc(mx.routeHTTP)
	mx.handler = chain(mx.middlewares, router)
}
371
372// handle registers a http.Handler in the routing tree for a particular http method
373// and routing pattern.
374func (mx *Mux) handle(method methodTyp, pattern string, handler http.Handler) *node {
375	if len(pattern) == 0 || pattern[0] != '/' {
376		panic(fmt.Sprintf("chi: routing pattern must begin with '/' in '%s'", pattern))
377	}
378
379	// Build the final routing handler for this Mux.
380	if !mx.inline && mx.handler == nil {
381		mx.buildRouteHandler()
382	}
383
384	// Build endpoint handler with inline middlewares for the route
385	var h http.Handler
386	if mx.inline {
387		mx.handler = http.HandlerFunc(mx.routeHTTP)
388		h = Chain(mx.middlewares...).Handler(handler)
389	} else {
390		h = handler
391	}
392
393	// Add the endpoint to the tree and return the node
394	return mx.tree.InsertRoute(method, pattern, h)
395}
396
// routeHTTP routes a request through the Mux routing tree and serves the
// matching handler. It expects a *Context under RouteCtxKey (installed by
// ServeHTTP) and panics otherwise.
//
// The path to match is rctx.RoutePath if set (e.g. by a parent Mount), else
// the escaped r.URL.RawPath if non-empty, else r.URL.Path. Unsupported
// methods and method mismatches get the 405 handler; misses get the 404
// handler.
func (mx *Mux) routeHTTP(w http.ResponseWriter, r *http.Request) {
	rctx := r.Context().Value(RouteCtxKey).(*Context)

	path := rctx.RoutePath
	if path == "" {
		path = r.URL.RawPath
	}
	if path == "" {
		path = r.URL.Path
	}

	if rctx.RouteMethod == "" {
		rctx.RouteMethod = r.Method
	}
	mt, supported := methodMap[rctx.RouteMethod]
	if !supported {
		mx.MethodNotAllowedHandler().ServeHTTP(w, r)
		return
	}

	_, _, h := mx.tree.FindRoute(rctx, mt, path)
	switch {
	case h != nil:
		h.ServeHTTP(w, r)
	case rctx.methodNotAllowed:
		mx.MethodNotAllowedHandler().ServeHTTP(w, r)
	default:
		mx.NotFoundHandler().ServeHTTP(w, r)
	}
}
434
435func (mx *Mux) nextRoutePath(rctx *Context) string {
436	routePath := "/"
437	nx := len(rctx.routeParams.Keys) - 1 // index of last param in list
438	if nx >= 0 && rctx.routeParams.Keys[nx] == "*" && len(rctx.routeParams.Values) > nx {
439		routePath += rctx.routeParams.Values[nx]
440	}
441	return routePath
442}
443
444// Recursively update data on child routers.
445func (mx *Mux) updateSubRoutes(fn func(subMux *Mux)) {
446	for _, r := range mx.tree.routes() {
447		subMux, ok := r.SubRoutes.(*Mux)
448		if !ok {
449			continue
450		}
451		fn(subMux)
452	}
453}
454
// methodNotAllowedHandler is the package's fallback 405 responder, used when
// no custom handler was set via Mux.MethodNotAllowed. It writes the status
// and an empty body.
func methodNotAllowedHandler(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusMethodNotAllowed)
	// Zero-length body write; its result carries nothing worth reporting.
	_, _ = w.Write(nil)
}
461