1package chi
2
// Radix tree implementation below is based on the original work by
// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go
// (MIT licensed). It's been heavily modified for use as an HTTP routing tree.
6
7import (
8	"fmt"
9	"net/http"
10	"regexp"
11	"sort"
12	"strconv"
13	"strings"
14)
15
// methodTyp is a bit flag identifying an HTTP method. Every method owns one
// bit, so several methods can be combined with a bitwise OR.
type methodTyp uint

// Method flags. Each constant is the next power of two, starting with mSTUB
// at bit 0.
const (
	// mSTUB has no entry in methodMap and is not part of mALL.
	mSTUB methodTyp = 1 << iota
	mCONNECT
	mDELETE
	mGET
	mHEAD
	mOPTIONS
	mPATCH
	mPOST
	mPUT
	mTRACE
)
30
// mALL is the union of all method flags except mSTUB. RegisterMethod adds the
// flag of every custom method to it.
var mALL = mCONNECT | mDELETE | mGET | mHEAD |
	mOPTIONS | mPATCH | mPOST | mPUT | mTRACE
33
// methodMap maps an HTTP method name to its flag. RegisterMethod inserts
// additional entries for custom methods.
var methodMap = map[string]methodTyp{
	http.MethodConnect: mCONNECT,
	http.MethodDelete:  mDELETE,
	http.MethodGet:     mGET,
	http.MethodHead:    mHEAD,
	http.MethodOptions: mOPTIONS,
	http.MethodPatch:   mPATCH,
	http.MethodPost:    mPOST,
	http.MethodPut:     mPUT,
	http.MethodTrace:   mTRACE,
}
45
46// RegisterMethod adds support for custom HTTP method handlers, available
47// via Router#Method and Router#MethodFunc
48func RegisterMethod(method string) {
49	if method == "" {
50		return
51	}
52	method = strings.ToUpper(method)
53	if _, ok := methodMap[method]; ok {
54		return
55	}
56	n := len(methodMap)
57	if n > strconv.IntSize-2 {
58		panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize))
59	}
60	mt := methodTyp(2 << n)
61	methodMap[method] = mt
62	mALL |= mt
63}
64
// nodeTyp classifies a tree node by the kind of pattern segment it matches.
// The values double as indexes into node.children, so their order is also the
// order in which the child groups are iterated.
type nodeTyp uint8

const (
	ntStatic   nodeTyp = iota // /home
	ntRegexp                  // /{id:[0-9]+}
	ntParam                   // /{user}
	ntCatchAll                // /api/v1/*
)
73
// node is a single node of the routing tree.
type node struct {
	// subroutes on the leaf node
	subroutes Routes

	// regexp matcher for regexp nodes
	rex *regexp.Regexp

	// HTTP handler endpoints on the leaf node, keyed by method flag.
	// A nil map means the node is not a leaf (see isLeaf).
	endpoints endpoints

	// prefix is the static text matched by a static node; regexp nodes
	// store their regexp pattern here instead.
	prefix string

	// child nodes should be stored in-order for iteration,
	// in groups of the node type. The array is indexed by nodeTyp and
	// addChild re-sorts a group after every insertion.
	children [ntCatchAll + 1]nodes

	// tail is the byte that follows a param/regexp segment in the pattern,
	// i.e. the delimiter that ends the param's value while routing.
	tail byte

	// node type: static, regexp, param, catchAll
	typ nodeTyp

	// label is the first byte of the pattern text the node was created for;
	// child groups are sorted and searched by it.
	label byte
}
100
// endpoints is a mapping of http method constants to handlers
// for a given route.
type endpoints map[methodTyp]*endpoint

// endpoint is what a node stores per HTTP method.
type endpoint struct {
	// endpoint handler
	handler http.Handler

	// pattern is the routing pattern for handler nodes
	pattern string

	// parameter keys recorded on handler nodes, in the order in which
	// they appear in pattern
	paramKeys []string
}
115
116func (s endpoints) Value(method methodTyp) *endpoint {
117	mh, ok := s[method]
118	if !ok {
119		mh = &endpoint{}
120		s[method] = mh
121	}
122	return mh
123}
124
// InsertRoute registers handler for the given method and routing pattern in
// the tree rooted at n, creating new nodes or splitting existing static nodes
// as needed. It returns the node that holds the endpoint.
func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node {
	var parent *node
	// search is the part of the pattern that still has to be placed in the tree
	search := pattern

	for {
		// Handle key exhaustion
		if len(search) == 0 {
			// Insert or update the node's leaf handler
			n.setEndpoint(method, handler, pattern)
			return n
		}

		// We're going to be searching for a wild node next,
		// in this case, we need to get the tail.
		// For a static segment these stay at their zero values, which
		// makes segTyp ntStatic.
		var label = search[0]
		var segTail byte
		var segEndIdx int
		var segTyp nodeTyp
		var segRexpat string
		if label == '{' || label == '*' {
			segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search)
		}

		// Regexp edges are additionally identified by their regexp pattern
		var prefix string
		if segTyp == ntRegexp {
			prefix = segRexpat
		}

		// Look for the edge to attach to
		parent = n
		n = n.getEdge(segTyp, label, segTail, prefix)

		// No edge, create one
		if n == nil {
			child := &node{label: label, tail: segTail, prefix: search}
			hn := parent.addChild(child, search)
			hn.setEndpoint(method, handler, pattern)

			return hn
		}

		// Found an edge to match the pattern

		if n.typ > ntStatic {
			// We found a param node, trim the param from the search path and continue.
			// This param/wild pattern segment would already be on the tree from a previous
			// call to addChild when creating a new node.
			search = search[segEndIdx:]
			continue
		}

		// Static nodes fall below here.
		// Determine longest prefix of the search key on match.
		commonPrefix := longestPrefix(search, n.prefix)
		if commonPrefix == len(n.prefix) {
			// the common prefix is as long as the current node's prefix we're attempting to insert.
			// keep the search going.
			search = search[commonPrefix:]
			continue
		}

		// Split the node: a new static node takes over the shared prefix and
		// replaces n below parent.
		child := &node{
			typ:    ntStatic,
			prefix: search[:commonPrefix],
		}
		parent.replaceChild(search[0], segTail, child)

		// Restore the existing node below the new one, reduced to the part
		// of its prefix that is not shared
		n.label = n.prefix[commonPrefix]
		n.prefix = n.prefix[commonPrefix:]
		child.addChild(n, n.prefix)

		// If the new key is a subset, set the method/handler on this node and finish.
		search = search[commonPrefix:]
		if len(search) == 0 {
			child.setEndpoint(method, handler, pattern)
			return child
		}

		// Create a new edge for the node
		subchild := &node{
			typ:    ntStatic,
			label:  search[0],
			prefix: search,
		}
		hn := child.addChild(subchild, search)
		hn.setEndpoint(method, handler, pattern)
		return hn
	}
}
216
// addChild appends the new `child` node to the tree using the `pattern` as the trie key.
// For a URL router like chi's, we split the static, param, regexp and wildcard segments
// into different nodes. In addition, addChild will recursively call itself until every
// pattern segment is added to the url pattern tree as individual nodes, depending on type.
//
// It returns the node that has to receive the route's endpoint: child itself,
// or the last node created for the segments that follow it. It panics when a
// regexp segment does not compile.
func (n *node) addChild(child *node, prefix string) *node {
	search := prefix

	// handler leaf node added to the tree is the child.
	// this may be overridden later down the flow
	hn := child

	// Parse next segment
	segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search)

	// Add child depending on next up segment
	switch segTyp {

	case ntStatic:
		// Search prefix is all static (that is, has no params in path)
		// noop

	default:
		// Search prefix contains a param, regexp or wildcard

		if segTyp == ntRegexp {
			rex, err := regexp.Compile(segRexpat)
			if err != nil {
				panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat))
			}
			child.prefix = segRexpat
			child.rex = rex
		}

		if segStartIdx == 0 {
			// Route starts with a param
			child.typ = segTyp

			// From here on segStartIdx marks where the rest of the pattern
			// begins: a catch-all consumes everything, any other param ends
			// at segEndIdx.
			if segTyp == ntCatchAll {
				segStartIdx = -1
			} else {
				segStartIdx = segEndIdx
			}
			if segStartIdx < 0 {
				segStartIdx = len(search)
			}
			child.tail = segTail // for params, we set the tail

			if segStartIdx != len(search) {
				// add static edge for the remaining part, split the end.
				// it's not possible to have adjacent param nodes, so it's certainly
				// going to be a static node next.

				search = search[segStartIdx:] // advance search position

				nn := &node{
					typ:    ntStatic,
					label:  search[0],
					prefix: search,
				}
				hn = child.addChild(nn, search)
			}

		} else if segStartIdx > 0 {
			// Route has some param

			// starts with a static segment, so child only keeps the static
			// text in front of the param
			child.typ = ntStatic
			child.prefix = search[:segStartIdx]
			child.rex = nil

			// add the param edge node
			search = search[segStartIdx:]

			nn := &node{
				typ:   segTyp,
				label: search[0],
				tail:  segTail,
			}
			hn = child.addChild(nn, search)

		}
	}

	// Store child in the group of its (final) type and restore the group's order
	n.children[child.typ] = append(n.children[child.typ], child)
	n.children[child.typ].Sort()
	return hn
}
304
305func (n *node) replaceChild(label, tail byte, child *node) {
306	for i := 0; i < len(n.children[child.typ]); i++ {
307		if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail {
308			n.children[child.typ][i] = child
309			n.children[child.typ][i].label = label
310			n.children[child.typ][i].tail = tail
311			return
312		}
313	}
314	panic("chi: replacing missing child")
315}
316
317func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node {
318	nds := n.children[ntyp]
319	for i := 0; i < len(nds); i++ {
320		if nds[i].label == label && nds[i].tail == tail {
321			if ntyp == ntRegexp && nds[i].prefix != prefix {
322				continue
323			}
324			return nds[i]
325		}
326	}
327	return nil
328}
329
330func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) {
331	// Set the handler for the method type on the node
332	if n.endpoints == nil {
333		n.endpoints = make(endpoints)
334	}
335
336	paramKeys := patParamKeys(pattern)
337
338	if method&mSTUB == mSTUB {
339		n.endpoints.Value(mSTUB).handler = handler
340	}
341	if method&mALL == mALL {
342		h := n.endpoints.Value(mALL)
343		h.handler = handler
344		h.pattern = pattern
345		h.paramKeys = paramKeys
346		for _, m := range methodMap {
347			h := n.endpoints.Value(m)
348			h.handler = handler
349			h.pattern = pattern
350			h.paramKeys = paramKeys
351		}
352	} else {
353		h := n.endpoints.Value(method)
354		h.handler = handler
355		h.pattern = pattern
356		h.paramKeys = paramKeys
357	}
358}
359
360func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) {
361	// Reset the context routing pattern and params
362	rctx.routePattern = ""
363	rctx.routeParams.Keys = rctx.routeParams.Keys[:0]
364	rctx.routeParams.Values = rctx.routeParams.Values[:0]
365
366	// Find the routing handlers for the path
367	rn := n.findRoute(rctx, method, path)
368	if rn == nil {
369		return nil, nil, nil
370	}
371
372	// Record the routing params in the request lifecycle
373	rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...)
374	rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...)
375
376	// Record the routing pattern in the request lifecycle
377	if rn.endpoints[method].pattern != "" {
378		rctx.routePattern = rn.endpoints[method].pattern
379		rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern)
380	}
381
382	return rn, rn.endpoints, rn.endpoints[method].handler
383}
384
// Recursive edge traversal by checking all nodeTyp groups along the way.
// It's like searching through a multi-dimensional radix trie.
//
// findRoute returns the node below n that matches path and has a handler for
// method, or nil. The child groups are tried in nodeTyp order. While
// descending, param values are appended to rctx.routeParams.Values and removed
// again when a branch turns out not to match; the param keys are appended once
// the final node is found. rctx.methodNotAllowed is set when a leaf matches the
// path but has no handler for method.
func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node {
	nn := n
	search := path

	for t, nds := range nn.children {
		ntyp := nodeTyp(t)
		if len(nds) == 0 {
			continue
		}

		// xn is the candidate child, xsearch what is left of the path after it
		var xn *node
		xsearch := search

		var label byte
		if search != "" {
			label = search[0]
		}

		switch ntyp {
		case ntStatic:
			xn = nds.findEdge(label)
			if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) {
				continue
			}
			xsearch = xsearch[len(xn.prefix):]

		case ntParam, ntRegexp:
			// short-circuit and return no matching route for empty param values
			if xsearch == "" {
				continue
			}

			// serially loop through each node grouped by the tail delimiter
			for idx := 0; idx < len(nds); idx++ {
				xn = nds[idx]

				// the node's tail is the delimiter byte that ends the param value
				p := strings.IndexByte(xsearch, xn.tail)

				if p < 0 {
					// no delimiter left: only a '/'-tailed param may take the
					// whole rest of the path
					if xn.tail == '/' {
						p = len(xsearch)
					} else {
						continue
					}
				} else if ntyp == ntRegexp && p == 0 {
					// a regexp param is never matched against an empty value
					continue
				}

				if ntyp == ntRegexp && xn.rex != nil {
					if !xn.rex.MatchString(xsearch[:p]) {
						continue
					}
				} else if strings.IndexByte(xsearch[:p], '/') != -1 {
					// avoid a match across path segments
					continue
				}

				// remember the length so the value can be dropped on backtracking
				prevlen := len(rctx.routeParams.Values)
				rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p])
				xsearch = xsearch[p:]

				if len(xsearch) == 0 {
					if xn.isLeaf() {
						h := xn.endpoints[method]
						if h != nil && h.handler != nil {
							rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
							return xn
						}

						// flag that the routing context found a route, but not a corresponding
						// supported method
						rctx.methodNotAllowed = true
					}
				}

				// recursively find the next node on this branch
				fin := xn.findRoute(rctx, method, xsearch)
				if fin != nil {
					return fin
				}

				// not found on this branch, reset vars
				rctx.routeParams.Values = rctx.routeParams.Values[:prevlen]
				xsearch = search
			}

			// No node of this group led to a route. Push an empty placeholder
			// value; the code below makes one more attempt through the last
			// node of the group with the unconsumed path and pops the
			// placeholder again if that fails.
			rctx.routeParams.Values = append(rctx.routeParams.Values, "")

		default:
			// catch-all nodes take the whole rest of the path as their value
			rctx.routeParams.Values = append(rctx.routeParams.Values, search)
			xn = nds[0]
			xsearch = ""
		}

		if xn == nil {
			continue
		}

		// did we find it yet?
		if len(xsearch) == 0 {
			if xn.isLeaf() {
				h := xn.endpoints[method]
				if h != nil && h.handler != nil {
					rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
					return xn
				}

				// flag that the routing context found a route, but not a corresponding
				// supported method
				rctx.methodNotAllowed = true
			}
		}

		// recursively find the next node..
		fin := xn.findRoute(rctx, method, xsearch)
		if fin != nil {
			return fin
		}

		// Did not find final handler, let's remove the param here if it was set
		if xn.typ > ntStatic {
			if len(rctx.routeParams.Values) > 0 {
				rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1]
			}
		}

	}

	return nil
}
519
520func (n *node) findEdge(ntyp nodeTyp, label byte) *node {
521	nds := n.children[ntyp]
522	num := len(nds)
523	idx := 0
524
525	switch ntyp {
526	case ntStatic, ntParam, ntRegexp:
527		i, j := 0, num-1
528		for i <= j {
529			idx = i + (j-i)/2
530			if label > nds[idx].label {
531				i = idx + 1
532			} else if label < nds[idx].label {
533				j = idx - 1
534			} else {
535				i = num // breaks cond
536			}
537		}
538		if nds[idx].label != label {
539			return nil
540		}
541		return nds[idx]
542
543	default: // catch all
544		return nds[idx]
545	}
546}
547
// isLeaf reports whether endpoints have been set on the node, i.e. whether its
// endpoints map is non-nil.
func (n *node) isLeaf() bool {
	return n.endpoints != nil
}
551
// findPattern reports whether the routing pattern can be followed through the
// tree below n, segment by segment, until it is used up.
//
// pattern must not be empty: its first byte is read to select the edge. Note
// that the first child group that yields a matching edge decides the result;
// the outcome of the recursive call is returned as is, without trying the
// remaining groups.
func (n *node) findPattern(pattern string) bool {
	nn := n
	for _, nds := range nn.children {
		if len(nds) == 0 {
			continue
		}

		n = nn.findEdge(nds[0].typ, pattern[0])
		if n == nil {
			continue
		}

		// idx is the number of pattern bytes covered by the edge
		var idx int
		var xpattern string

		switch n.typ {
		case ntStatic:
			idx = longestPrefix(pattern, n.prefix)
			if idx < len(n.prefix) {
				// the node's prefix is only partially shared, no match
				continue
			}

		case ntParam, ntRegexp:
			// skip to just past the first closing brace
			idx = strings.IndexByte(pattern, '}') + 1

		case ntCatchAll:
			idx = longestPrefix(pattern, "*")

		default:
			panic("chi: unknown node type")
		}

		xpattern = pattern[idx:]
		if len(xpattern) == 0 {
			return true
		}

		return n.findPattern(xpattern)
	}
	return false
}
593
594func (n *node) routes() []Route {
595	rts := []Route{}
596
597	n.walk(func(eps endpoints, subroutes Routes) bool {
598		if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil {
599			return false
600		}
601
602		// Group methodHandlers by unique patterns
603		pats := make(map[string]endpoints)
604
605		for mt, h := range eps {
606			if h.pattern == "" {
607				continue
608			}
609			p, ok := pats[h.pattern]
610			if !ok {
611				p = endpoints{}
612				pats[h.pattern] = p
613			}
614			p[mt] = h
615		}
616
617		for p, mh := range pats {
618			hs := make(map[string]http.Handler)
619			if mh[mALL] != nil && mh[mALL].handler != nil {
620				hs["*"] = mh[mALL].handler
621			}
622
623			for mt, h := range mh {
624				if h.handler == nil {
625					continue
626				}
627				m := methodTypString(mt)
628				if m == "" {
629					continue
630				}
631				hs[m] = h.handler
632			}
633
634			rt := Route{subroutes, hs, p}
635			rts = append(rts, rt)
636		}
637
638		return false
639	})
640
641	return rts
642}
643
644func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool {
645	// Visit the leaf values if any
646	if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) {
647		return true
648	}
649
650	// Recurse on the children
651	for _, ns := range n.children {
652		for _, cn := range ns {
653			if cn.walk(fn) {
654				return true
655			}
656		}
657	}
658	return false
659}
660
661// patNextSegment returns the next segment details from a pattern:
662// node type, param key, regexp string, param tail byte, param starting index, param ending index
663func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) {
664	ps := strings.Index(pattern, "{")
665	ws := strings.Index(pattern, "*")
666
667	if ps < 0 && ws < 0 {
668		return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing
669	}
670
671	// Sanity check
672	if ps >= 0 && ws >= 0 && ws < ps {
673		panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'")
674	}
675
676	var tail byte = '/' // Default endpoint tail to / byte
677
678	if ps >= 0 {
679		// Param/Regexp pattern is next
680		nt := ntParam
681
682		// Read to closing } taking into account opens and closes in curl count (cc)
683		cc := 0
684		pe := ps
685		for i, c := range pattern[ps:] {
686			if c == '{' {
687				cc++
688			} else if c == '}' {
689				cc--
690				if cc == 0 {
691					pe = ps + i
692					break
693				}
694			}
695		}
696		if pe == ps {
697			panic("chi: route param closing delimiter '}' is missing")
698		}
699
700		key := pattern[ps+1 : pe]
701		pe++ // set end to next position
702
703		if pe < len(pattern) {
704			tail = pattern[pe]
705		}
706
707		var rexpat string
708		if idx := strings.Index(key, ":"); idx >= 0 {
709			nt = ntRegexp
710			rexpat = key[idx+1:]
711			key = key[:idx]
712		}
713
714		if len(rexpat) > 0 {
715			if rexpat[0] != '^' {
716				rexpat = "^" + rexpat
717			}
718			if rexpat[len(rexpat)-1] != '$' {
719				rexpat += "$"
720			}
721		}
722
723		return nt, key, rexpat, tail, ps, pe
724	}
725
726	// Wildcard pattern as finale
727	if ws < len(pattern)-1 {
728		panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead")
729	}
730	return ntCatchAll, "*", "", 0, ws, len(pattern)
731}
732
733func patParamKeys(pattern string) []string {
734	pat := pattern
735	paramKeys := []string{}
736	for {
737		ptyp, paramKey, _, _, _, e := patNextSegment(pat)
738		if ptyp == ntStatic {
739			return paramKeys
740		}
741		for i := 0; i < len(paramKeys); i++ {
742			if paramKeys[i] == paramKey {
743				panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey))
744			}
745		}
746		paramKeys = append(paramKeys, paramKey)
747		pat = pat[e:]
748	}
749}
750
// longestPrefix returns the number of leading bytes that k1 and k2 have in
// common.
func longestPrefix(k1, k2 string) int {
	limit := len(k1)
	if len(k2) < limit {
		limit = len(k2)
	}

	for i := 0; i < limit; i++ {
		if k1[i] != k2[i] {
			return i
		}
	}
	// the shorter string is a prefix of the other one
	return limit
}
766
767func methodTypString(method methodTyp) string {
768	for s, t := range methodMap {
769		if method == t {
770			return s
771		}
772	}
773	return ""
774}
775
776type nodes []*node
777
778// Sort the list of nodes by label
779func (ns nodes) Sort()              { sort.Sort(ns); ns.tailSort() }
780func (ns nodes) Len() int           { return len(ns) }
781func (ns nodes) Swap(i, j int)      { ns[i], ns[j] = ns[j], ns[i] }
782func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label }
783
784// tailSort pushes nodes with '/' as the tail to the end of the list for param nodes.
785// The list order determines the traversal order.
786func (ns nodes) tailSort() {
787	for i := len(ns) - 1; i >= 0; i-- {
788		if ns[i].typ > ntStatic && ns[i].tail == '/' {
789			ns.Swap(i, len(ns)-1)
790			return
791		}
792	}
793}
794
795func (ns nodes) findEdge(label byte) *node {
796	num := len(ns)
797	idx := 0
798	i, j := 0, num-1
799	for i <= j {
800		idx = i + (j-i)/2
801		if label > ns[idx].label {
802			i = idx + 1
803		} else if label < ns[idx].label {
804			j = idx - 1
805		} else {
806			i = num // breaks cond
807		}
808	}
809	if ns[idx].label != label {
810		return nil
811	}
812	return ns[idx]
813}
814
// Route describes the details of a routing handler.
// Handlers map key is an HTTP method
type Route struct {
	// SubRoutes is the router mounted on the route, if any
	SubRoutes Routes
	// Handlers maps an HTTP method name to its handler; the key "*" is
	// used for a handler that serves all methods
	Handlers map[string]http.Handler
	// Pattern is the routing pattern of the route
	Pattern string
}
822
// WalkFunc is the type of the function called for each method and route visited by Walk.
// It receives the method name, the full route pattern including the patterns
// of all parent routers, the handler, and the middlewares that apply to it.
// Returning a non-nil error stops the walk.
type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error
825
// Walk walks any router tree that implements Routes interface, calling walkFn
// for every method of every route, subroutes included. It returns the first
// error returned by walkFn, or nil.
func Walk(r Routes, walkFn WalkFunc) error {
	return walk(r, walkFn, "")
}
830
831func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error {
832	for _, route := range r.Routes() {
833		mws := make([]func(http.Handler) http.Handler, len(parentMw))
834		copy(mws, parentMw)
835		mws = append(mws, r.Middlewares()...)
836
837		if route.SubRoutes != nil {
838			if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil {
839				return err
840			}
841			continue
842		}
843
844		for method, handler := range route.Handlers {
845			if method == "*" {
846				// Ignore a "catchAll" method, since we pass down all the specific methods for each route.
847				continue
848			}
849
850			fullRoute := parentRoute + route.Pattern
851			fullRoute = strings.Replace(fullRoute, "/*/", "/", -1)
852
853			if chain, ok := handler.(*ChainHandler); ok {
854				if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil {
855					return err
856				}
857			} else {
858				if err := walkFn(method, fullRoute, handler, mws...); err != nil {
859					return err
860				}
861			}
862		}
863	}
864
865	return nil
866}
867