1// Copyright 2012 The Gorilla Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mux
6
7import (
8	"errors"
9	"fmt"
10	"net/http"
11	"path"
12	"regexp"
13)
14
// ErrMethodMismatch is recorded in RouteMatch.MatchErr when the request
// method does not match the method(s) a route accepts.
var ErrMethodMismatch = errors.New("method is not allowed")

// ErrNotFound is recorded in RouteMatch.MatchErr by Router.Match when no
// registered route matched the request.
var ErrNotFound = errors.New("no matching route was found")
19
20// NewRouter returns a new router instance.
21func NewRouter() *Router {
22	return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
23}
24
// Router registers routes to be matched and dispatches a handler.
//
// It implements the http.Handler interface, so it can be registered to serve
// requests:
//
//     var router = mux.NewRouter()
//
//     func main() {
//         http.Handle("/", router)
//     }
//
// Or, for Google App Engine, register it in an init() function:
//
//     func init() {
//         http.Handle("/", router)
//     }
//
// This will send all incoming requests to the router.
type Router struct {
	// Configurable Handler to be used when no route matches.
	// Router.Match hands it out (with MatchErr set to ErrNotFound); when it
	// is nil, ServeHTTP falls back to http.NotFoundHandler().
	NotFoundHandler http.Handler

	// Configurable Handler to be used when the request method does not match the route.
	// Router.Match hands it out when MatchErr is ErrMethodMismatch; when it
	// is nil, ServeHTTP replies with a bare 405 status.
	MethodNotAllowedHandler http.Handler

	// Parent route, if this is a subrouter.
	// The build scheme, named routes, regexp group and build vars are all
	// delegated to it when set.
	parent parentRoute
	// Routes to be matched, in order.
	routes []*Route
	// Routes by name for URL building.
	// Lazily initialized by getNamedRoutes; a subrouter shares its parent's map.
	namedRoutes map[string]*Route
	// See Router.StrictSlash(). This defines the flag for new routes.
	strictSlash bool
	// See Router.SkipClean(). This defines the flag for new routes.
	// ServeHTTP also consults it to decide whether to redirect requests
	// whose path is not in canonical form.
	skipClean bool
	// If true, do not clear the request context after handling the request.
	// This has no effect when go1.7+ is used, since the context is stored
	// on the request itself.
	KeepContext bool
	// See Router.UseEncodedPath(). This defines a flag for all routes.
	// It is copied into each new route and makes ServeHTTP clean the
	// escaped path instead of the decoded one.
	useEncodedPath bool
}
67
68// Match attempts to match the given request against the router's registered routes.
69//
70// If the request matches a route of this router or one of its subrouters the Route,
71// Handler, and Vars fields of the the match argument are filled and this function
72// returns true.
73//
74// If the request does not match any of this router's or its subrouters' routes
75// then this function returns false. If available, a reason for the match failure
76// will be filled in the match argument's MatchErr field. If the match failure type
77// (eg: not found) has a registered handler, the handler is assigned to the Handler
78// field of the match argument.
79func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
80	for _, route := range r.routes {
81		if route.Match(req, match) {
82			return true
83		}
84	}
85
86	if match.MatchErr == ErrMethodMismatch {
87		if r.MethodNotAllowedHandler != nil {
88			match.Handler = r.MethodNotAllowedHandler
89			return true
90		} else {
91			return false
92		}
93	}
94
95	// Closest match for a router (includes sub-routers)
96	if r.NotFoundHandler != nil {
97		match.Handler = r.NotFoundHandler
98		match.MatchErr = ErrNotFound
99		return true
100	}
101
102	match.MatchErr = ErrNotFound
103	return false
104}
105
106// ServeHTTP dispatches the handler registered in the matched route.
107//
108// When there is a match, the route variables can be retrieved calling
109// mux.Vars(request).
110func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
111	if !r.skipClean {
112		path := req.URL.Path
113		if r.useEncodedPath {
114			path = req.URL.EscapedPath()
115		}
116		// Clean path to canonical form and redirect.
117		if p := cleanPath(path); p != path {
118
119			// Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query.
120			// This matches with fix in go 1.2 r.c. 4 for same problem.  Go Issue:
121			// http://code.google.com/p/go/issues/detail?id=5252
122			url := *req.URL
123			url.Path = p
124			p = url.String()
125
126			w.Header().Set("Location", p)
127			w.WriteHeader(http.StatusMovedPermanently)
128			return
129		}
130	}
131	var match RouteMatch
132	var handler http.Handler
133	if r.Match(req, &match) {
134		handler = match.Handler
135		req = setVars(req, match.Vars)
136		req = setCurrentRoute(req, match.Route)
137	}
138
139	if handler == nil && match.MatchErr == ErrMethodMismatch {
140		handler = methodNotAllowedHandler()
141	}
142
143	if handler == nil {
144		handler = http.NotFoundHandler()
145	}
146
147	if !r.KeepContext {
148		defer contextClear(req)
149	}
150	handler.ServeHTTP(w, req)
151}
152
153// Get returns a route registered with the given name.
154func (r *Router) Get(name string) *Route {
155	return r.getNamedRoutes()[name]
156}
157
158// GetRoute returns a route registered with the given name. This method
159// was renamed to Get() and remains here for backwards compatibility.
160func (r *Router) GetRoute(name string) *Route {
161	return r.getNamedRoutes()[name]
162}
163
164// StrictSlash defines the trailing slash behavior for new routes. The initial
165// value is false.
166//
167// When true, if the route path is "/path/", accessing "/path" will redirect
168// to the former and vice versa. In other words, your application will always
169// see the path as specified in the route.
170//
171// When false, if the route path is "/path", accessing "/path/" will not match
172// this route and vice versa.
173//
174// Special case: when a route sets a path prefix using the PathPrefix() method,
175// strict slash is ignored for that route because the redirect behavior can't
176// be determined from a prefix alone. However, any subrouters created from that
177// route inherit the original StrictSlash setting.
178func (r *Router) StrictSlash(value bool) *Router {
179	r.strictSlash = value
180	return r
181}
182
183// SkipClean defines the path cleaning behaviour for new routes. The initial
184// value is false. Users should be careful about which routes are not cleaned
185//
186// When true, if the route path is "/path//to", it will remain with the double
187// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/
188//
189// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will
190// become /fetch/http/xkcd.com/534
191func (r *Router) SkipClean(value bool) *Router {
192	r.skipClean = value
193	return r
194}
195
196// UseEncodedPath tells the router to match the encoded original path
197// to the routes.
198// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to".
199// This behavior has the drawback of needing to match routes against
200// r.RequestURI instead of r.URL.Path. Any modifications (such as http.StripPrefix)
201// to r.URL.Path will not affect routing when this flag is on and thus may
202// induce unintended behavior.
203//
204// If not called, the router will match the unencoded path to the routes.
205// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to"
206func (r *Router) UseEncodedPath() *Router {
207	r.useEncodedPath = true
208	return r
209}
210
211// ----------------------------------------------------------------------------
212// parentRoute
213// ----------------------------------------------------------------------------
214
215func (r *Router) getBuildScheme() string {
216	if r.parent != nil {
217		return r.parent.getBuildScheme()
218	}
219	return ""
220}
221
222// getNamedRoutes returns the map where named routes are registered.
223func (r *Router) getNamedRoutes() map[string]*Route {
224	if r.namedRoutes == nil {
225		if r.parent != nil {
226			r.namedRoutes = r.parent.getNamedRoutes()
227		} else {
228			r.namedRoutes = make(map[string]*Route)
229		}
230	}
231	return r.namedRoutes
232}
233
234// getRegexpGroup returns regexp definitions from the parent route, if any.
235func (r *Router) getRegexpGroup() *routeRegexpGroup {
236	if r.parent != nil {
237		return r.parent.getRegexpGroup()
238	}
239	return nil
240}
241
242func (r *Router) buildVars(m map[string]string) map[string]string {
243	if r.parent != nil {
244		m = r.parent.buildVars(m)
245	}
246	return m
247}
248
249// ----------------------------------------------------------------------------
250// Route factories
251// ----------------------------------------------------------------------------
252
253// NewRoute registers an empty route.
254func (r *Router) NewRoute() *Route {
255	route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath}
256	r.routes = append(r.routes, route)
257	return route
258}
259
260// Handle registers a new route with a matcher for the URL path.
261// See Route.Path() and Route.Handler().
262func (r *Router) Handle(path string, handler http.Handler) *Route {
263	return r.NewRoute().Path(path).Handler(handler)
264}
265
266// HandleFunc registers a new route with a matcher for the URL path.
267// See Route.Path() and Route.HandlerFunc().
268func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
269	*http.Request)) *Route {
270	return r.NewRoute().Path(path).HandlerFunc(f)
271}
272
273// Headers registers a new route with a matcher for request header values.
274// See Route.Headers().
275func (r *Router) Headers(pairs ...string) *Route {
276	return r.NewRoute().Headers(pairs...)
277}
278
279// Host registers a new route with a matcher for the URL host.
280// See Route.Host().
281func (r *Router) Host(tpl string) *Route {
282	return r.NewRoute().Host(tpl)
283}
284
285// MatcherFunc registers a new route with a custom matcher function.
286// See Route.MatcherFunc().
287func (r *Router) MatcherFunc(f MatcherFunc) *Route {
288	return r.NewRoute().MatcherFunc(f)
289}
290
291// Methods registers a new route with a matcher for HTTP methods.
292// See Route.Methods().
293func (r *Router) Methods(methods ...string) *Route {
294	return r.NewRoute().Methods(methods...)
295}
296
297// Path registers a new route with a matcher for the URL path.
298// See Route.Path().
299func (r *Router) Path(tpl string) *Route {
300	return r.NewRoute().Path(tpl)
301}
302
303// PathPrefix registers a new route with a matcher for the URL path prefix.
304// See Route.PathPrefix().
305func (r *Router) PathPrefix(tpl string) *Route {
306	return r.NewRoute().PathPrefix(tpl)
307}
308
309// Queries registers a new route with a matcher for URL query values.
310// See Route.Queries().
311func (r *Router) Queries(pairs ...string) *Route {
312	return r.NewRoute().Queries(pairs...)
313}
314
315// Schemes registers a new route with a matcher for URL schemes.
316// See Route.Schemes().
317func (r *Router) Schemes(schemes ...string) *Route {
318	return r.NewRoute().Schemes(schemes...)
319}
320
321// BuildVarsFunc registers a new route with a custom function for modifying
322// route variables before building a URL.
323func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
324	return r.NewRoute().BuildVarsFunc(f)
325}
326
327// Walk walks the router and all its sub-routers, calling walkFn for each route
328// in the tree. The routes are walked in the order they were added. Sub-routers
329// are explored depth-first.
330func (r *Router) Walk(walkFn WalkFunc) error {
331	return r.walk(walkFn, []*Route{})
332}
333
var (
	// SkipRouter can be returned by a WalkFunc to tell Walk not to descend
	// into the sub-routers of the route it was just called with.
	SkipRouter = errors.New("skip this router")
)
337
// WalkFunc is the type of the function called for each route visited by Walk.
// At every invocation, it is given the current route, and the current router,
// and a list of ancestor routes that lead to the current route.
//
// If the function returns SkipRouter, Walk does not descend into the
// sub-routers of the current route but continues with the remaining routes;
// any other non-nil error aborts the walk and is returned by Walk.
//
// The ancestors slice is reused and modified as the walk proceeds; copy it
// if it must be retained after the function returns.
type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
342
343func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
344	for _, t := range r.routes {
345		err := walkFn(t, r, ancestors)
346		if err == SkipRouter {
347			continue
348		}
349		if err != nil {
350			return err
351		}
352		for _, sr := range t.matchers {
353			if h, ok := sr.(*Router); ok {
354				ancestors = append(ancestors, t)
355				err := h.walk(walkFn, ancestors)
356				if err != nil {
357					return err
358				}
359				ancestors = ancestors[:len(ancestors)-1]
360			}
361		}
362		if h, ok := t.handler.(*Router); ok {
363			ancestors = append(ancestors, t)
364			err := h.walk(walkFn, ancestors)
365			if err != nil {
366				return err
367			}
368			ancestors = ancestors[:len(ancestors)-1]
369		}
370	}
371	return nil
372}
373
374// ----------------------------------------------------------------------------
375// Context
376// ----------------------------------------------------------------------------
377
// RouteMatch stores information about a matched route.
// Router.Match fills it in; ServeHTTP then serves Handler and stores Vars and
// Route on the request so that Vars and CurrentRoute can return them.
type RouteMatch struct {
	// Route is the matched route. Handler is the handler to serve; it may also
	// be the router's NotFoundHandler or MethodNotAllowedHandler, as assigned
	// by Router.Match. Vars holds the route variables for the request.
	Route   *Route
	Handler http.Handler
	Vars    map[string]string

	// MatchErr is set to appropriate matching error
	// It is set to ErrMethodMismatch if there is a mismatch in
	// the request method and route method
	// Router.Match sets it to ErrNotFound when no route matched and the
	// failure was not a method mismatch.
	MatchErr error
}
389
// contextKey is the type of the keys under which mux stores per-request
// values. Being unexported, it cannot collide with keys from other packages.
type contextKey int

// Keys for the values mux stores on each request.
const (
	varsKey  contextKey = 0 // route variables, read back by Vars
	routeKey contextKey = 1 // matched route, read back by CurrentRoute
)
396
397// Vars returns the route variables for the current request, if any.
398func Vars(r *http.Request) map[string]string {
399	if rv := contextGet(r, varsKey); rv != nil {
400		return rv.(map[string]string)
401	}
402	return nil
403}
404
405// CurrentRoute returns the matched route for the current request, if any.
406// This only works when called inside the handler of the matched route
407// because the matched route is stored in the request context which is cleared
408// after the handler returns, unless the KeepContext option is set on the
409// Router.
410func CurrentRoute(r *http.Request) *Route {
411	if rv := contextGet(r, routeKey); rv != nil {
412		return rv.(*Route)
413	}
414	return nil
415}
416
417func setVars(r *http.Request, val interface{}) *http.Request {
418	return contextSet(r, varsKey, val)
419}
420
421func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
422	return contextSet(r, routeKey, val)
423}
424
425// ----------------------------------------------------------------------------
426// Helpers
427// ----------------------------------------------------------------------------
428
// cleanPath returns the canonical path for p: rooted, with "." and ".."
// elements resolved and repeated slashes collapsed. A trailing slash on p is
// preserved. Borrowed from the net/http package.
func cleanPath(p string) string {
	switch {
	case p == "":
		return "/"
	case p[0] != '/':
		p = "/" + p
	}
	cleaned := path.Clean(p)
	// path.Clean drops any trailing slash except on the root, so restore it
	// to keep "/dir/" distinct from "/dir".
	if cleaned != "/" && p[len(p)-1] == '/' {
		cleaned += "/"
	}
	return cleaned
}
447
// uniqueVars returns an error naming the first element of s1 (in s1 order)
// that also appears in s2, or nil if the slices share no strings.
func uniqueVars(s1, s2 []string) error {
	for i := range s1 {
		for j := range s2 {
			if s1[i] != s2[j] {
				continue
			}
			return fmt.Errorf("mux: duplicated route variable %q", s2[j])
		}
	}
	return nil
}
459
// checkPairs returns len(pairs) together with an error when that count is
// odd, i.e. when pairs cannot be split into key/value pairs.
func checkPairs(pairs ...string) (int, error) {
	n := len(pairs)
	if n%2 == 0 {
		return n, nil
	}
	return n, fmt.Errorf("mux: number of parameters must be multiple of 2, got %v", pairs)
}

// mapFromPairsToString turns key1, value1, key2, value2, ... into a
// key -> value map. It fails when the number of arguments is odd.
func mapFromPairsToString(pairs ...string) (map[string]string, error) {
	n, err := checkPairs(pairs...)
	if err != nil {
		return nil, err
	}
	out := make(map[string]string, n/2)
	for i := 1; i < n; i += 2 {
		out[pairs[i-1]] = pairs[i]
	}
	return out, nil
}

// mapFromPairsToRegex turns key1, pattern1, key2, pattern2, ... into a
// key -> compiled regexp map. It fails when the number of arguments is odd
// or when a pattern does not compile; the compile error is returned as-is.
func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
	n, err := checkPairs(pairs...)
	if err != nil {
		return nil, err
	}
	compiled := make(map[string]*regexp.Regexp, n/2)
	for i := 1; i < n; i += 2 {
		re, compileErr := regexp.Compile(pairs[i])
		if compileErr != nil {
			return nil, compileErr
		}
		compiled[pairs[i-1]] = re
	}
	return compiled, nil
}
502
// matchInArray reports whether value is one of the elements of arr.
func matchInArray(arr []string, value string) bool {
	for i := range arr {
		if arr[i] == value {
			return true
		}
	}
	return false
}
512
// matchMapWithString reports whether every key/value pair in toCheck is
// present in toMatch. Keys are canonicalized with http.CanonicalHeaderKey
// first when canonicalKey is true. An empty expected value only requires the
// key to be present (its value list may be empty but not nil); otherwise one
// of the key's values must equal it exactly.
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
nextKey:
	for key, want := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}
		got := toMatch[key]
		if got == nil {
			return false
		}
		if want == "" {
			continue
		}
		for _, candidate := range got {
			if candidate == want {
				continue nextKey
			}
		}
		return false
	}
	return true
}
539
540// matchMapWithRegex returns true if the given key/value pairs exist in a given map compiled against
541// the given regex
542func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
543	for k, v := range toCheck {
544		// Check if key exists.
545		if canonicalKey {
546			k = http.CanonicalHeaderKey(k)
547		}
548		if values := toMatch[k]; values == nil {
549			return false
550		} else if v != nil {
551			// If value was defined as an empty string we only check that the
552			// key exists. Otherwise we also check for equality.
553			valueExists := false
554			for _, value := range values {
555				if v.MatchString(value) {
556					valueExists = true
557					break
558				}
559			}
560			if !valueExists {
561				return false
562			}
563		}
564	}
565	return true
566}
567
// methodNotAllowed replies to the request with an HTTP status code 405 and
// writes no body.
func methodNotAllowed(w http.ResponseWriter, _ *http.Request) {
	const status = http.StatusMethodNotAllowed
	w.WriteHeader(status)
}

// methodNotAllowedHandler returns a simple request handler that replies to
// each request with a status code 405.
func methodNotAllowedHandler() http.Handler {
	var h http.HandlerFunc = methodNotAllowed
	return h
}
576