1// Copyright 2012 The Gorilla Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mux
6
7import (
8	"errors"
9	"fmt"
10	"net/http"
11	"path"
12	"regexp"
13)
14
// Sentinel errors reported through RouteMatch.MatchErr.
var (
	// ErrMethodMismatch is returned when the method in the request does not
	// match the method defined against the route.
	ErrMethodMismatch error = errors.New("method is not allowed")
	// ErrNotFound is returned when no route match is found.
	ErrNotFound error = errors.New("no matching route was found")
)
22
23// NewRouter returns a new router instance.
24func NewRouter() *Router {
25	return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
26}
27
// Router registers routes to be matched and dispatches a handler.
//
// It implements the http.Handler interface, so it can be registered to serve
// requests:
//
//     var router = mux.NewRouter()
//
//     func main() {
//         http.Handle("/", router)
//     }
//
// Or, for Google App Engine, register it in an init() function:
//
//     func init() {
//         http.Handle("/", router)
//     }
//
// This will send all incoming requests to the router.
type Router struct {
	// NotFoundHandler, if non-nil, is used when no route matches the request.
	// When nil, ServeHTTP falls back to http.NotFoundHandler().
	NotFoundHandler http.Handler

	// MethodNotAllowedHandler, if non-nil, is used when the request method
	// does not match the route (MatchErr == ErrMethodMismatch). When nil,
	// ServeHTTP replies with a bare 405 status.
	MethodNotAllowedHandler http.Handler

	// Parent route, if this is a subrouter; nil for a top-level router.
	parent parentRoute
	// Routes to be matched, in registration order.
	routes []*Route
	// Routes by name for URL building. Lazily shared with the parent's
	// table by getNamedRoutes when this router has none of its own.
	namedRoutes map[string]*Route
	// See Router.StrictSlash(). This is copied into each new route.
	strictSlash bool
	// See Router.SkipClean(). This is copied into each new route and also
	// disables the path-cleaning redirect in ServeHTTP.
	skipClean bool
	// If true, do not clear the request context after handling the request.
	// This has no effect when go1.7+ is used, since the context is stored
	// on the request itself.
	KeepContext bool
	// See Router.UseEncodedPath(). This is copied into each new route and
	// also selects the escaped path for cleaning in ServeHTTP.
	useEncodedPath bool
	// Middlewares wrapped around the matched handler, applied by Match only
	// when a route matches without error. The first entry ends up outermost.
	middlewares []middleware
}
72
73// Match attempts to match the given request against the router's registered routes.
74//
75// If the request matches a route of this router or one of its subrouters the Route,
76// Handler, and Vars fields of the the match argument are filled and this function
77// returns true.
78//
79// If the request does not match any of this router's or its subrouters' routes
80// then this function returns false. If available, a reason for the match failure
81// will be filled in the match argument's MatchErr field. If the match failure type
82// (eg: not found) has a registered handler, the handler is assigned to the Handler
83// field of the match argument.
84func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
85	for _, route := range r.routes {
86		if route.Match(req, match) {
87			// Build middleware chain if no error was found
88			if match.MatchErr == nil {
89				for i := len(r.middlewares) - 1; i >= 0; i-- {
90					match.Handler = r.middlewares[i].Middleware(match.Handler)
91				}
92			}
93			return true
94		}
95	}
96
97	if match.MatchErr == ErrMethodMismatch {
98		if r.MethodNotAllowedHandler != nil {
99			match.Handler = r.MethodNotAllowedHandler
100			return true
101		}
102
103		return false
104	}
105
106	// Closest match for a router (includes sub-routers)
107	if r.NotFoundHandler != nil {
108		match.Handler = r.NotFoundHandler
109		match.MatchErr = ErrNotFound
110		return true
111	}
112
113	match.MatchErr = ErrNotFound
114	return false
115}
116
117// ServeHTTP dispatches the handler registered in the matched route.
118//
119// When there is a match, the route variables can be retrieved calling
120// mux.Vars(request).
121func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
122	if !r.skipClean {
123		path := req.URL.Path
124		if r.useEncodedPath {
125			path = req.URL.EscapedPath()
126		}
127		// Clean path to canonical form and redirect.
128		if p := cleanPath(path); p != path {
129
130			// Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query.
131			// This matches with fix in go 1.2 r.c. 4 for same problem.  Go Issue:
132			// http://code.google.com/p/go/issues/detail?id=5252
133			url := *req.URL
134			url.Path = p
135			p = url.String()
136
137			w.Header().Set("Location", p)
138			w.WriteHeader(http.StatusMovedPermanently)
139			return
140		}
141	}
142	var match RouteMatch
143	var handler http.Handler
144	if r.Match(req, &match) {
145		handler = match.Handler
146		req = setVars(req, match.Vars)
147		req = setCurrentRoute(req, match.Route)
148	}
149
150	if handler == nil && match.MatchErr == ErrMethodMismatch {
151		handler = methodNotAllowedHandler()
152	}
153
154	if handler == nil {
155		handler = http.NotFoundHandler()
156	}
157
158	if !r.KeepContext {
159		defer contextClear(req)
160	}
161
162	handler.ServeHTTP(w, req)
163}
164
165// Get returns a route registered with the given name.
166func (r *Router) Get(name string) *Route {
167	return r.getNamedRoutes()[name]
168}
169
170// GetRoute returns a route registered with the given name. This method
171// was renamed to Get() and remains here for backwards compatibility.
172func (r *Router) GetRoute(name string) *Route {
173	return r.getNamedRoutes()[name]
174}
175
176// StrictSlash defines the trailing slash behavior for new routes. The initial
177// value is false.
178//
179// When true, if the route path is "/path/", accessing "/path" will perform a redirect
180// to the former and vice versa. In other words, your application will always
181// see the path as specified in the route.
182//
183// When false, if the route path is "/path", accessing "/path/" will not match
184// this route and vice versa.
185//
186// The re-direct is a HTTP 301 (Moved Permanently). Note that when this is set for
187// routes with a non-idempotent method (e.g. POST, PUT), the subsequent re-directed
188// request will be made as a GET by most clients. Use middleware or client settings
189// to modify this behaviour as needed.
190//
191// Special case: when a route sets a path prefix using the PathPrefix() method,
192// strict slash is ignored for that route because the redirect behavior can't
193// be determined from a prefix alone. However, any subrouters created from that
194// route inherit the original StrictSlash setting.
195func (r *Router) StrictSlash(value bool) *Router {
196	r.strictSlash = value
197	return r
198}
199
200// SkipClean defines the path cleaning behaviour for new routes. The initial
201// value is false. Users should be careful about which routes are not cleaned
202//
203// When true, if the route path is "/path//to", it will remain with the double
204// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/
205//
206// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will
207// become /fetch/http/xkcd.com/534
208func (r *Router) SkipClean(value bool) *Router {
209	r.skipClean = value
210	return r
211}
212
213// UseEncodedPath tells the router to match the encoded original path
214// to the routes.
215// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to".
216//
217// If not called, the router will match the unencoded path to the routes.
218// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to"
219func (r *Router) UseEncodedPath() *Router {
220	r.useEncodedPath = true
221	return r
222}
223
224// ----------------------------------------------------------------------------
225// parentRoute
226// ----------------------------------------------------------------------------
227
228func (r *Router) getBuildScheme() string {
229	if r.parent != nil {
230		return r.parent.getBuildScheme()
231	}
232	return ""
233}
234
235// getNamedRoutes returns the map where named routes are registered.
236func (r *Router) getNamedRoutes() map[string]*Route {
237	if r.namedRoutes == nil {
238		if r.parent != nil {
239			r.namedRoutes = r.parent.getNamedRoutes()
240		} else {
241			r.namedRoutes = make(map[string]*Route)
242		}
243	}
244	return r.namedRoutes
245}
246
247// getRegexpGroup returns regexp definitions from the parent route, if any.
248func (r *Router) getRegexpGroup() *routeRegexpGroup {
249	if r.parent != nil {
250		return r.parent.getRegexpGroup()
251	}
252	return nil
253}
254
255func (r *Router) buildVars(m map[string]string) map[string]string {
256	if r.parent != nil {
257		m = r.parent.buildVars(m)
258	}
259	return m
260}
261
262// ----------------------------------------------------------------------------
263// Route factories
264// ----------------------------------------------------------------------------
265
266// NewRoute registers an empty route.
267func (r *Router) NewRoute() *Route {
268	route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath}
269	r.routes = append(r.routes, route)
270	return route
271}
272
273// Handle registers a new route with a matcher for the URL path.
274// See Route.Path() and Route.Handler().
275func (r *Router) Handle(path string, handler http.Handler) *Route {
276	return r.NewRoute().Path(path).Handler(handler)
277}
278
279// HandleFunc registers a new route with a matcher for the URL path.
280// See Route.Path() and Route.HandlerFunc().
281func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
282	*http.Request)) *Route {
283	return r.NewRoute().Path(path).HandlerFunc(f)
284}
285
286// Headers registers a new route with a matcher for request header values.
287// See Route.Headers().
288func (r *Router) Headers(pairs ...string) *Route {
289	return r.NewRoute().Headers(pairs...)
290}
291
292// Host registers a new route with a matcher for the URL host.
293// See Route.Host().
294func (r *Router) Host(tpl string) *Route {
295	return r.NewRoute().Host(tpl)
296}
297
298// MatcherFunc registers a new route with a custom matcher function.
299// See Route.MatcherFunc().
300func (r *Router) MatcherFunc(f MatcherFunc) *Route {
301	return r.NewRoute().MatcherFunc(f)
302}
303
304// Methods registers a new route with a matcher for HTTP methods.
305// See Route.Methods().
306func (r *Router) Methods(methods ...string) *Route {
307	return r.NewRoute().Methods(methods...)
308}
309
310// Path registers a new route with a matcher for the URL path.
311// See Route.Path().
312func (r *Router) Path(tpl string) *Route {
313	return r.NewRoute().Path(tpl)
314}
315
316// PathPrefix registers a new route with a matcher for the URL path prefix.
317// See Route.PathPrefix().
318func (r *Router) PathPrefix(tpl string) *Route {
319	return r.NewRoute().PathPrefix(tpl)
320}
321
322// Queries registers a new route with a matcher for URL query values.
323// See Route.Queries().
324func (r *Router) Queries(pairs ...string) *Route {
325	return r.NewRoute().Queries(pairs...)
326}
327
328// Schemes registers a new route with a matcher for URL schemes.
329// See Route.Schemes().
330func (r *Router) Schemes(schemes ...string) *Route {
331	return r.NewRoute().Schemes(schemes...)
332}
333
334// BuildVarsFunc registers a new route with a custom function for modifying
335// route variables before building a URL.
336func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
337	return r.NewRoute().BuildVarsFunc(f)
338}
339
340// Walk walks the router and all its sub-routers, calling walkFn for each route
341// in the tree. The routes are walked in the order they were added. Sub-routers
342// are explored depth-first.
343func (r *Router) Walk(walkFn WalkFunc) error {
344	return r.walk(walkFn, []*Route{})
345}
346
// SkipRouter can be returned by a WalkFunc to tell Walk not to descend into
// the sub-routers of the route that was just visited. Walking continues with
// that route's next sibling.
var SkipRouter error = errors.New("skip this router")
350
// WalkFunc is the type of the function called for each route visited by Walk.
// At every invocation, it is given the current route, the router that owns
// it, and the list of ancestor routes (outermost first) that lead to the
// current route. Returning SkipRouter skips the sub-routers of the current
// route; returning any other non-nil error stops the walk, and Walk returns
// that error.
type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
355
356func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
357	for _, t := range r.routes {
358		err := walkFn(t, r, ancestors)
359		if err == SkipRouter {
360			continue
361		}
362		if err != nil {
363			return err
364		}
365		for _, sr := range t.matchers {
366			if h, ok := sr.(*Router); ok {
367				ancestors = append(ancestors, t)
368				err := h.walk(walkFn, ancestors)
369				if err != nil {
370					return err
371				}
372				ancestors = ancestors[:len(ancestors)-1]
373			}
374		}
375		if h, ok := t.handler.(*Router); ok {
376			ancestors = append(ancestors, t)
377			err := h.walk(walkFn, ancestors)
378			if err != nil {
379				return err
380			}
381			ancestors = ancestors[:len(ancestors)-1]
382		}
383	}
384	return nil
385}
386
387// ----------------------------------------------------------------------------
388// Context
389// ----------------------------------------------------------------------------
390
// RouteMatch stores information about a matched route. It is filled in by
// Router.Match.
type RouteMatch struct {
	// Route is the route that matched, if any.
	Route   *Route
	// Handler is the handler to dispatch to: the matched route's handler
	// (wrapped in the router's middlewares on a clean match) or, on failure,
	// the router's MethodNotAllowedHandler / NotFoundHandler when set.
	Handler http.Handler
	// Vars holds the route variables of the match; ServeHTTP stores them on
	// the request so handlers can read them with mux.Vars.
	Vars    map[string]string

	// MatchErr is set to appropriate matching error
	// It is set to ErrMethodMismatch if there is a mismatch in
	// the request method and route method, and to ErrNotFound by
	// Router.Match when no route matches.
	MatchErr error
}
402
// contextKey is the unexported type of the keys under which the router stores
// per-request data, so they cannot collide with keys defined by other packages.
type contextKey int

// Keys for the matched route variables and the matched route.
const (
	varsKey  contextKey = 0
	routeKey contextKey = 1
)
409
410// Vars returns the route variables for the current request, if any.
411func Vars(r *http.Request) map[string]string {
412	if rv := contextGet(r, varsKey); rv != nil {
413		return rv.(map[string]string)
414	}
415	return nil
416}
417
418// CurrentRoute returns the matched route for the current request, if any.
419// This only works when called inside the handler of the matched route
420// because the matched route is stored in the request context which is cleared
421// after the handler returns, unless the KeepContext option is set on the
422// Router.
423func CurrentRoute(r *http.Request) *Route {
424	if rv := contextGet(r, routeKey); rv != nil {
425		return rv.(*Route)
426	}
427	return nil
428}
429
430func setVars(r *http.Request, val interface{}) *http.Request {
431	return contextSet(r, varsKey, val)
432}
433
434func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
435	return contextSet(r, routeKey, val)
436}
437
438// ----------------------------------------------------------------------------
439// Helpers
440// ----------------------------------------------------------------------------
441
// cleanPath returns the canonical path for p: rooted at "/", with . and ..
// elements and repeated slashes removed, keeping a trailing slash when p had
// one (except for the root itself). The empty path becomes "/".
// Borrowed from the net/http package.
func cleanPath(p string) string {
	if p == "" {
		return "/"
	}
	rooted := p
	if rooted[0] != '/' {
		rooted = "/" + rooted
	}
	cleaned := path.Clean(rooted)
	// path.Clean drops the trailing slash for everything but the root.
	keepSlash := rooted[len(rooted)-1] == '/' && cleaned != "/"
	if !keepSlash {
		return cleaned
	}
	return cleaned + "/"
}
460
// uniqueVars returns an error naming the first string of s1 (in order) that
// also appears in s2, or nil if the two slices share no strings.
func uniqueVars(s1, s2 []string) error {
	seen := make(map[string]struct{}, len(s2))
	for _, name := range s2 {
		seen[name] = struct{}{}
	}
	for _, name := range s1 {
		if _, dup := seen[name]; dup {
			return fmt.Errorf("mux: duplicated route variable %q", name)
		}
	}
	return nil
}
472
// checkPairs returns the number of strings passed in, together with an error
// when that number is odd (the strings are expected as key/value pairs).
func checkPairs(pairs ...string) (int, error) {
	n := len(pairs)
	if n%2 == 0 {
		return n, nil
	}
	return n, fmt.Errorf(
		"mux: number of parameters must be multiple of 2, got %v", pairs)
}
483
484// mapFromPairsToString converts variadic string parameters to a
485// string to string map.
486func mapFromPairsToString(pairs ...string) (map[string]string, error) {
487	length, err := checkPairs(pairs...)
488	if err != nil {
489		return nil, err
490	}
491	m := make(map[string]string, length/2)
492	for i := 0; i < length; i += 2 {
493		m[pairs[i]] = pairs[i+1]
494	}
495	return m, nil
496}
497
498// mapFromPairsToRegex converts variadic string parameters to a
499// string to regex map.
500func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
501	length, err := checkPairs(pairs...)
502	if err != nil {
503		return nil, err
504	}
505	m := make(map[string]*regexp.Regexp, length/2)
506	for i := 0; i < length; i += 2 {
507		regex, err := regexp.Compile(pairs[i+1])
508		if err != nil {
509			return nil, err
510		}
511		m[pairs[i]] = regex
512	}
513	return m, nil
514}
515
// matchInArray reports whether value is one of the elements of arr.
func matchInArray(arr []string, value string) bool {
	for i := range arr {
		if arr[i] == value {
			return true
		}
	}
	return false
}
525
// matchMapWithString reports whether every key of toCheck has a non-nil entry
// in toMatch and, for keys whose expected value is non-empty, that value is
// among the key's values. An empty expected value only requires the key to be
// present. When canonicalKey is true, keys of toCheck are converted with
// http.CanonicalHeaderKey before the lookup.
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
nextKey:
	for key, want := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}
		got := toMatch[key]
		if got == nil {
			return false
		}
		if want == "" {
			// Presence of the key is all that is required.
			continue
		}
		for _, candidate := range got {
			if candidate == want {
				continue nextKey
			}
		}
		return false
	}
	return true
}
552
// matchMapWithRegex reports whether every key of toCheck has a non-nil entry
// in toMatch and, for keys mapped to a non-nil regexp, at least one of the
// key's values matches that regexp. A nil regexp only requires the key to be
// present. When canonicalKey is true, keys of toCheck are converted with
// http.CanonicalHeaderKey before the lookup.
func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
nextKey:
	for key, pattern := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}
		candidates := toMatch[key]
		if candidates == nil {
			return false
		}
		if pattern == nil {
			// Presence of the key is all that is required.
			continue
		}
		for _, candidate := range candidates {
			if pattern.MatchString(candidate) {
				continue nextKey
			}
		}
		return false
	}
	return true
}
580
// methodNotAllowed replies to the request with a bare HTTP 405 status and an
// empty body; the request itself is not inspected.
func methodNotAllowed(w http.ResponseWriter, r *http.Request) {
	const status = http.StatusMethodNotAllowed
	w.WriteHeader(status)
}
585
586// methodNotAllowedHandler returns a simple request handler
587// that replies to each request with a status code 405.
588func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) }
589