1// Copyright 2012 The Gorilla Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mux
6
7import (
8	"errors"
9	"fmt"
10	"net/http"
11	"path"
12	"regexp"
13	"strings"
14)
15
16// NewRouter returns a new router instance.
17func NewRouter() *Router {
18	return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
19}
20
// Router registers routes to be matched and dispatches a handler.
//
// It implements the http.Handler interface, so it can be registered to serve
// requests:
//
//     var router = mux.NewRouter()
//
//     func main() {
//         http.Handle("/", router)
//     }
//
// Or, for Google App Engine, register it in an init() function:
//
//     func init() {
//         http.Handle("/", router)
//     }
//
// This will send all incoming requests to the router.
type Router struct {
	// Configurable Handler to be used when no route matches. When nil,
	// ServeHTTP falls back to http.NotFoundHandler().
	NotFoundHandler http.Handler
	// Parent route, if this is a subrouter. When set, the build scheme, the
	// named-route map, the regexp group and build vars are looked up via it.
	parent parentRoute
	// Routes to be matched, in registration order; Match stops at the first
	// route that matches.
	routes []*Route
	// Routes by name for URL building. Lazily initialized by getNamedRoutes,
	// which takes the parent's map when this router has a parent.
	namedRoutes map[string]*Route
	// See Router.StrictSlash(). This defines the flag for new routes: it is
	// copied into each route created by NewRoute.
	strictSlash bool
	// See Router.SkipClean(). This defines the flag for new routes (copied by
	// NewRoute) and also tells ServeHTTP whether to clean/redirect the path.
	skipClean bool
	// If true, do not clear the request context after handling the request.
	// This has no effect when go1.7+ is used, since the context is stored
	// on the request itself.
	KeepContext bool
	// See Router.UseEncodedPath(). Copied into each route created by NewRoute
	// and also used by ServeHTTP to choose which path form is cleaned.
	useEncodedPath bool
}
59
60// Match matches registered routes against the request.
61func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
62	for _, route := range r.routes {
63		if route.Match(req, match) {
64			return true
65		}
66	}
67
68	// Closest match for a router (includes sub-routers)
69	if r.NotFoundHandler != nil {
70		match.Handler = r.NotFoundHandler
71		return true
72	}
73	return false
74}
75
76// ServeHTTP dispatches the handler registered in the matched route.
77//
78// When there is a match, the route variables can be retrieved calling
79// mux.Vars(request).
80func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
81	if !r.skipClean {
82		path := req.URL.Path
83		if r.useEncodedPath {
84			path = getPath(req)
85		}
86		// Clean path to canonical form and redirect.
87		if p := cleanPath(path); p != path {
88
89			// Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query.
90			// This matches with fix in go 1.2 r.c. 4 for same problem.  Go Issue:
91			// http://code.google.com/p/go/issues/detail?id=5252
92			url := *req.URL
93			url.Path = p
94			p = url.String()
95
96			w.Header().Set("Location", p)
97			w.WriteHeader(http.StatusMovedPermanently)
98			return
99		}
100	}
101	var match RouteMatch
102	var handler http.Handler
103	if r.Match(req, &match) {
104		handler = match.Handler
105		req = setVars(req, match.Vars)
106		req = setCurrentRoute(req, match.Route)
107	}
108	if handler == nil {
109		handler = http.NotFoundHandler()
110	}
111	if !r.KeepContext {
112		defer contextClear(req)
113	}
114	handler.ServeHTTP(w, req)
115}
116
117// Get returns a route registered with the given name.
118func (r *Router) Get(name string) *Route {
119	return r.getNamedRoutes()[name]
120}
121
122// GetRoute returns a route registered with the given name. This method
123// was renamed to Get() and remains here for backwards compatibility.
124func (r *Router) GetRoute(name string) *Route {
125	return r.getNamedRoutes()[name]
126}
127
128// StrictSlash defines the trailing slash behavior for new routes. The initial
129// value is false.
130//
131// When true, if the route path is "/path/", accessing "/path" will redirect
132// to the former and vice versa. In other words, your application will always
133// see the path as specified in the route.
134//
135// When false, if the route path is "/path", accessing "/path/" will not match
136// this route and vice versa.
137//
138// Special case: when a route sets a path prefix using the PathPrefix() method,
139// strict slash is ignored for that route because the redirect behavior can't
140// be determined from a prefix alone. However, any subrouters created from that
141// route inherit the original StrictSlash setting.
142func (r *Router) StrictSlash(value bool) *Router {
143	r.strictSlash = value
144	return r
145}
146
147// SkipClean defines the path cleaning behaviour for new routes. The initial
148// value is false. Users should be careful about which routes are not cleaned
149//
150// When true, if the route path is "/path//to", it will remain with the double
151// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/
152//
153// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will
154// become /fetch/http/xkcd.com/534
155func (r *Router) SkipClean(value bool) *Router {
156	r.skipClean = value
157	return r
158}
159
160// UseEncodedPath tells the router to match the encoded original path
161// to the routes.
162// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to".
163// This behavior has the drawback of needing to match routes against
164// r.RequestURI instead of r.URL.Path. Any modifications (such as http.StripPrefix)
165// to r.URL.Path will not affect routing when this flag is on and thus may
166// induce unintended behavior.
167//
168// If not called, the router will match the unencoded path to the routes.
169// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to"
170func (r *Router) UseEncodedPath() *Router {
171	r.useEncodedPath = true
172	return r
173}
174
175// ----------------------------------------------------------------------------
176// parentRoute
177// ----------------------------------------------------------------------------
178
179func (r *Router) getBuildScheme() string {
180	if r.parent != nil {
181		return r.parent.getBuildScheme()
182	}
183	return ""
184}
185
186// getNamedRoutes returns the map where named routes are registered.
187func (r *Router) getNamedRoutes() map[string]*Route {
188	if r.namedRoutes == nil {
189		if r.parent != nil {
190			r.namedRoutes = r.parent.getNamedRoutes()
191		} else {
192			r.namedRoutes = make(map[string]*Route)
193		}
194	}
195	return r.namedRoutes
196}
197
198// getRegexpGroup returns regexp definitions from the parent route, if any.
199func (r *Router) getRegexpGroup() *routeRegexpGroup {
200	if r.parent != nil {
201		return r.parent.getRegexpGroup()
202	}
203	return nil
204}
205
206func (r *Router) buildVars(m map[string]string) map[string]string {
207	if r.parent != nil {
208		m = r.parent.buildVars(m)
209	}
210	return m
211}
212
213// ----------------------------------------------------------------------------
214// Route factories
215// ----------------------------------------------------------------------------
216
217// NewRoute registers an empty route.
218func (r *Router) NewRoute() *Route {
219	route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath}
220	r.routes = append(r.routes, route)
221	return route
222}
223
224// Handle registers a new route with a matcher for the URL path.
225// See Route.Path() and Route.Handler().
226func (r *Router) Handle(path string, handler http.Handler) *Route {
227	return r.NewRoute().Path(path).Handler(handler)
228}
229
230// HandleFunc registers a new route with a matcher for the URL path.
231// See Route.Path() and Route.HandlerFunc().
232func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
233	*http.Request)) *Route {
234	return r.NewRoute().Path(path).HandlerFunc(f)
235}
236
237// Headers registers a new route with a matcher for request header values.
238// See Route.Headers().
239func (r *Router) Headers(pairs ...string) *Route {
240	return r.NewRoute().Headers(pairs...)
241}
242
243// Host registers a new route with a matcher for the URL host.
244// See Route.Host().
245func (r *Router) Host(tpl string) *Route {
246	return r.NewRoute().Host(tpl)
247}
248
249// MatcherFunc registers a new route with a custom matcher function.
250// See Route.MatcherFunc().
251func (r *Router) MatcherFunc(f MatcherFunc) *Route {
252	return r.NewRoute().MatcherFunc(f)
253}
254
255// Methods registers a new route with a matcher for HTTP methods.
256// See Route.Methods().
257func (r *Router) Methods(methods ...string) *Route {
258	return r.NewRoute().Methods(methods...)
259}
260
261// Path registers a new route with a matcher for the URL path.
262// See Route.Path().
263func (r *Router) Path(tpl string) *Route {
264	return r.NewRoute().Path(tpl)
265}
266
267// PathPrefix registers a new route with a matcher for the URL path prefix.
268// See Route.PathPrefix().
269func (r *Router) PathPrefix(tpl string) *Route {
270	return r.NewRoute().PathPrefix(tpl)
271}
272
273// Queries registers a new route with a matcher for URL query values.
274// See Route.Queries().
275func (r *Router) Queries(pairs ...string) *Route {
276	return r.NewRoute().Queries(pairs...)
277}
278
279// Schemes registers a new route with a matcher for URL schemes.
280// See Route.Schemes().
281func (r *Router) Schemes(schemes ...string) *Route {
282	return r.NewRoute().Schemes(schemes...)
283}
284
285// BuildVarsFunc registers a new route with a custom function for modifying
286// route variables before building a URL.
287func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
288	return r.NewRoute().BuildVarsFunc(f)
289}
290
291// Walk walks the router and all its sub-routers, calling walkFn for each route
292// in the tree. The routes are walked in the order they were added. Sub-routers
293// are explored depth-first.
294func (r *Router) Walk(walkFn WalkFunc) error {
295	return r.walk(walkFn, []*Route{})
296}
297
// SkipRouter can be returned by a WalkFunc to tell Walk not to descend into
// the sub-routers of the route that was just visited. The walk then carries
// on with the next route.
var SkipRouter error = errors.New("skip this router")
301
// WalkFunc is the type of the function called for each route visited by Walk.
// At every invocation, it is given the current route, and the current router,
// and a list of ancestor routes that lead to the current route.
//
// Returning SkipRouter skips the sub-routers of the current route. Returning
// any other non-nil error stops the walk, and Walk returns that error.
// The ancestors slice is extended and truncated in place as the walk
// proceeds, so copy it if it must be retained after the call returns.
type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
306
307func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
308	for _, t := range r.routes {
309		err := walkFn(t, r, ancestors)
310		if err == SkipRouter {
311			continue
312		}
313		if err != nil {
314			return err
315		}
316		for _, sr := range t.matchers {
317			if h, ok := sr.(*Router); ok {
318				ancestors = append(ancestors, t)
319				err := h.walk(walkFn, ancestors)
320				if err != nil {
321					return err
322				}
323				ancestors = ancestors[:len(ancestors)-1]
324			}
325		}
326		if h, ok := t.handler.(*Router); ok {
327			ancestors = append(ancestors, t)
328			err := h.walk(walkFn, ancestors)
329			if err != nil {
330				return err
331			}
332			ancestors = ancestors[:len(ancestors)-1]
333		}
334	}
335	return nil
336}
337
338// ----------------------------------------------------------------------------
339// Context
340// ----------------------------------------------------------------------------
341
// RouteMatch stores information about a matched route.
//
// ServeHTTP passes a zero RouteMatch to Router.Match. When Match reports
// true, ServeHTTP runs Handler and stores Vars and Route in the request
// context, where they can be read back with Vars and CurrentRoute.
// Router.Match sets Handler to the router's NotFoundHandler when no route
// matched and one is configured. Route and Vars are presumably filled in by
// Route.Match (not visible here).
type RouteMatch struct {
	Route   *Route
	Handler http.Handler
	Vars    map[string]string
}
348
// contextKey is the unexported type of the keys this package stores in the
// request context. Because the type is private, its keys cannot collide with
// keys defined by other packages.
type contextKey int

// Context keys for the route variables and the matched route.
const (
	varsKey  contextKey = 0
	routeKey contextKey = 1
)
355
356// Vars returns the route variables for the current request, if any.
357func Vars(r *http.Request) map[string]string {
358	if rv := contextGet(r, varsKey); rv != nil {
359		return rv.(map[string]string)
360	}
361	return nil
362}
363
364// CurrentRoute returns the matched route for the current request, if any.
365// This only works when called inside the handler of the matched route
366// because the matched route is stored in the request context which is cleared
367// after the handler returns, unless the KeepContext option is set on the
368// Router.
369func CurrentRoute(r *http.Request) *Route {
370	if rv := contextGet(r, routeKey); rv != nil {
371		return rv.(*Route)
372	}
373	return nil
374}
375
376func setVars(r *http.Request, val interface{}) *http.Request {
377	return contextSet(r, varsKey, val)
378}
379
380func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
381	return contextSet(r, routeKey, val)
382}
383
384// ----------------------------------------------------------------------------
385// Helpers
386// ----------------------------------------------------------------------------
387
// getPath returns the escaped path of req when possible. It does what
// URL.EscapedPath() (added in go1.5) does, by hand, so it also works on older
// Go versions.
//
// For requests with a RequestURI (i.e. server requests), the path is taken
// from RequestURI, which unlike URL.Path has not been percent-decoded:
//   - a leading "scheme://" and host (absolute-form requests) are trimmed;
//   - everything from the first '?' or '#' on is cut off.
//
// When RequestURI is empty, e.g. for a client-constructed request, the
// decoded URL.Path is returned instead.
func getPath(req *http.Request) string {
	if req.RequestURI == "" {
		return req.URL.Path
	}
	// Extract the path from RequestURI (which is escaped unlike URL.Path),
	// as detailed in https://golang.org/pkg/net/url/#URL
	// (server-side workaround for Go < 1.5):
	// http://localhost/path/here?v=1 -> /path/here
	p := req.RequestURI
	p = strings.TrimPrefix(p, req.URL.Scheme+`://`)
	p = strings.TrimPrefix(p, req.URL.Host)
	// An escaped path never contains a literal '?' or '#', so the FIRST of
	// them starts the query or fragment. Cutting at the LAST one would leak
	// part of a query that itself contains '?', e.g. "/a?next=/b?c" -> "/a?next=/b".
	if i := strings.IndexAny(p, "?#"); i >= 0 {
		p = p[:i]
	}
	return p
}
409
// cleanPath returns the canonical form of p: it is made absolute, "." and ".."
// elements and repeated slashes are resolved by path.Clean, and a trailing
// slash present in p is kept (except that the root stays "/"). An empty p
// yields "/". Borrowed from the net/http package.
func cleanPath(p string) string {
	switch {
	case p == "":
		return "/"
	case !strings.HasPrefix(p, "/"):
		p = "/" + p
	}

	cleaned := path.Clean(p)
	// path.Clean strips the trailing slash from everything but the root, yet
	// that slash is significant for routing, so restore it.
	if strings.HasSuffix(p, "/") && cleaned != "/" {
		return cleaned + "/"
	}
	return cleaned
}
428
// uniqueVars reports an error naming the first string of s1 (in order) that
// also appears in s2. It returns nil when the two slices share no string.
func uniqueVars(s1, s2 []string) error {
	for i := range s1 {
		for j := range s2 {
			if s1[i] != s2[j] {
				continue
			}
			return fmt.Errorf("mux: duplicated route variable %q", s2[j])
		}
	}
	return nil
}
440
// checkPairs returns how many strings were passed. The error is non-nil when
// that count is odd, i.e. the arguments cannot form key/value pairs.
func checkPairs(pairs ...string) (int, error) {
	n := len(pairs)
	if n%2 == 0 {
		return n, nil
	}
	return n, fmt.Errorf(
		"mux: number of parameters must be multiple of 2, got %v", pairs)
}
451
452// mapFromPairsToString converts variadic string parameters to a
453// string to string map.
454func mapFromPairsToString(pairs ...string) (map[string]string, error) {
455	length, err := checkPairs(pairs...)
456	if err != nil {
457		return nil, err
458	}
459	m := make(map[string]string, length/2)
460	for i := 0; i < length; i += 2 {
461		m[pairs[i]] = pairs[i+1]
462	}
463	return m, nil
464}
465
466// mapFromPairsToRegex converts variadic string parameters to a
467// string to regex map.
468func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
469	length, err := checkPairs(pairs...)
470	if err != nil {
471		return nil, err
472	}
473	m := make(map[string]*regexp.Regexp, length/2)
474	for i := 0; i < length; i += 2 {
475		regex, err := regexp.Compile(pairs[i+1])
476		if err != nil {
477			return nil, err
478		}
479		m[pairs[i]] = regex
480	}
481	return m, nil
482}
483
// matchInArray reports whether value equals any element of arr. A nil or
// empty arr never matches.
func matchInArray(arr []string, value string) bool {
	for i := range arr {
		if arr[i] == value {
			return true
		}
	}
	return false
}
493
// matchMapWithString reports whether every key of toCheck is present in
// toMatch (with a non-nil value slice). For keys whose expected value is
// non-empty, one of the key's values in toMatch must also equal it; an empty
// expected value only requires the key to be present. When canonicalKey is
// true, keys are passed through http.CanonicalHeaderKey before the lookup.
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
nextKey:
	for key, want := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}
		candidates := toMatch[key]
		switch {
		case candidates == nil:
			return false
		case want == "":
			continue
		}
		for _, candidate := range candidates {
			if candidate == want {
				continue nextKey
			}
		}
		return false
	}
	return true
}
520
// matchMapWithRegex reports whether every key of toCheck is present in
// toMatch (with a non-nil value slice). For keys whose regexp is non-nil, at
// least one of the key's values in toMatch must also match that regexp; a nil
// regexp only requires the key to be present. When canonicalKey is true, keys
// are passed through http.CanonicalHeaderKey before the lookup.
func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
nextKey:
	for key, re := range toCheck {
		if canonicalKey {
			key = http.CanonicalHeaderKey(key)
		}
		candidates := toMatch[key]
		switch {
		case candidates == nil:
			return false
		case re == nil:
			continue
		}
		for _, candidate := range candidates {
			if re.MatchString(candidate) {
				continue nextKey
			}
		}
		return false
	}
	return true
}
548