1package chi
2
// Radix tree implementation below is based on the original work by
// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go
// (MIT licensed). It has been heavily modified for use as an HTTP routing tree.
6
7import (
8	"fmt"
9	"math"
10	"net/http"
11	"regexp"
12	"sort"
13	"strconv"
14	"strings"
15)
16
// methodTyp is a bit set of HTTP methods: every known method owns exactly one
// bit, so several methods can be combined with |.
type methodTyp int

const (
	mSTUB methodTyp = 1 << iota
	mCONNECT
	mDELETE
	mGET
	mHEAD
	mOPTIONS
	mPATCH
	mPOST
	mPUT
	mTRACE
)

// mALL is the union of every method bit in methodMap (mSTUB is not included).
// RegisterMethod extends it when a custom method is added.
var mALL = mCONNECT | mDELETE | mGET | mHEAD |
	mOPTIONS | mPATCH | mPOST | mPUT | mTRACE

// methodMap maps an upper-case HTTP method name to its bit.
var methodMap = map[string]methodTyp{
	http.MethodConnect: mCONNECT,
	http.MethodDelete:  mDELETE,
	http.MethodGet:     mGET,
	http.MethodHead:    mHEAD,
	http.MethodOptions: mOPTIONS,
	http.MethodPatch:   mPATCH,
	http.MethodPost:    mPOST,
	http.MethodPut:     mPUT,
	http.MethodTrace:   mTRACE,
}

// RegisterMethod adds support for custom HTTP method handlers, available
// via Router#Method and Router#MethodFunc.
//
// The method name is upper-cased; empty or already-registered names are
// ignored. Each new method receives the next unused bit and is added to mALL.
// It panics when no positive bit of methodTyp is left.
func RegisterMethod(method string) {
	if method == "" {
		return
	}
	method = strings.ToUpper(method)
	if _, ok := methodMap[method]; ok {
		return
	}

	// Bit 0 is mSTUB, which is not stored in methodMap, so the methods in the
	// map occupy bits 1..len(methodMap) and the next free bit is one higher.
	// Using len(methodMap) itself would reuse mTRACE's bit.
	bit := len(methodMap) + 1

	// methodTyp is a signed int: the highest bit that still yields a positive,
	// exactly representable value is IntSize-2.
	if bit > strconv.IntSize-2 {
		panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize))
	}
	mt := methodTyp(math.Exp2(float64(bit)))
	methodMap[method] = mt
	mALL |= mt
}
65
// nodeTyp identifies what kind of path segment a tree node matches. Its value
// is also the index into node.children, so the declaration order below is the
// order in which findRoute and walk visit child groups.
type nodeTyp uint8

const (
	ntStatic   nodeTyp = iota // /home
	ntRegexp                  // /{id:[0-9]+}
	ntParam                   // /{user}
	ntCatchAll                // /api/v1/*
)

// node is a single vertex of the routing radix tree.
type node struct {
	// node type: static, regexp, param, catchAll
	typ nodeTyp

	// label is the byte sibling lookups key on: the first byte of prefix for
	// static nodes, '{' for param/regexp nodes and '*' for catch-all nodes.
	label byte

	// tail is the delimiter byte that terminates a param/regexp segment's
	// value: the byte right after the closing '}' in the pattern, or '/' when
	// the segment ends the pattern. findRoute scans the request path for it to
	// find where the value stops. It is normally zero for static and
	// catch-all nodes.
	tail byte

	// prefix is the literal path fragment matched by a static node. For a
	// regexp node it holds the anchored expression source, which getEdge uses
	// to tell regexp edges apart. For other node types it is not consulted
	// while matching.
	prefix string

	// regexp matcher for regexp nodes
	rex *regexp.Regexp

	// HTTP handler endpoints on the leaf node, keyed by method bit
	endpoints endpoints

	// subroutes on the leaf node; this field is set outside this file
	// (NOTE(review): presumably by the mux when mounting — confirm).
	subroutes Routes

	// children holds child nodes grouped by (and indexed by) nodeTyp. Each
	// group is re-sorted with nodes.Sort whenever addChild inserts into it.
	children [ntCatchAll + 1]nodes
}
101
102// endpoints is a mapping of http method constants to handlers
103// for a given route.
104type endpoints map[methodTyp]*endpoint
105
106type endpoint struct {
107	// endpoint handler
108	handler http.Handler
109
110	// pattern is the routing pattern for handler nodes
111	pattern string
112
113	// parameter keys recorded on handler nodes
114	paramKeys []string
115}
116
117func (s endpoints) Value(method methodTyp) *endpoint {
118	mh, ok := s[method]
119	if !ok {
120		mh = &endpoint{}
121		s[method] = mh
122	}
123	return mh
124}
125
// InsertRoute adds pattern to the tree rooted at n and records handler for the
// given method bits on the node where the pattern ends, returning that node.
//
// It follows existing edges while they match the remaining pattern. When only
// part of a static node's prefix matches, that node is split at the common
// prefix. Whatever part of the pattern is not yet in the tree is attached via
// addChild, which creates one node per static/param/regexp/catch-all segment.
// Malformed patterns panic (via patNextSegment, addChild or setEndpoint).
func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node {
	var parent *node
	search := pattern

	for {
		// Handle key exhaustion
		if len(search) == 0 {
			// Insert or update the node's leaf handler
			n.setEndpoint(method, handler, pattern)
			return n
		}

		// We're going to be searching for a wild node next,
		// in this case, we need to get the tail
		var label = search[0]
		var segTail byte
		var segEndIdx int
		var segTyp nodeTyp
		var segRexpat string
		if label == '{' || label == '*' {
			segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search)
		}

		// Regexp edges are additionally keyed by their anchored expression, so
		// two {param:regexp} segments with different expressions get separate
		// nodes (see getEdge).
		var prefix string
		if segTyp == ntRegexp {
			prefix = segRexpat
		}

		// Look for the edge to attach to
		parent = n
		n = n.getEdge(segTyp, label, segTail, prefix)

		// No edge, create one
		if n == nil {
			child := &node{label: label, tail: segTail, prefix: search}
			hn := parent.addChild(child, search)
			hn.setEndpoint(method, handler, pattern)

			return hn
		}

		// Found an edge to match the pattern

		if n.typ > ntStatic {
			// We found a param node, trim the param from the search path and continue.
			// This param/wild pattern segment would already be on the tree from a previous
			// call to addChild when creating a new node.
			search = search[segEndIdx:]
			continue
		}

		// Static nodes fall below here.
		// Determine longest prefix of the search key on match.
		commonPrefix := longestPrefix(search, n.prefix)
		if commonPrefix == len(n.prefix) {
			// the common prefix is as long as the current node's prefix we're attempting to insert.
			// keep the search going.
			search = search[commonPrefix:]
			continue
		}

		// Split the node: a new static node holding only the shared prefix
		// takes n's slot under parent (replaceChild also sets its label/tail).
		child := &node{
			typ:    ntStatic,
			prefix: search[:commonPrefix],
		}
		parent.replaceChild(search[0], segTail, child)

		// Restore the existing node: n keeps the unshared remainder of its
		// prefix and is re-attached as a child of the new split node.
		n.label = n.prefix[commonPrefix]
		n.prefix = n.prefix[commonPrefix:]
		child.addChild(n, n.prefix)

		// If the new key is a subset, set the method/handler on this node and finish.
		search = search[commonPrefix:]
		if len(search) == 0 {
			child.setEndpoint(method, handler, pattern)
			return child
		}

		// Create a new edge for the node: the rest of the pattern becomes a
		// sibling of the restored node under the split node.
		subchild := &node{
			typ:    ntStatic,
			label:  search[0],
			prefix: search,
		}
		hn := child.addChild(subchild, search)
		hn.setEndpoint(method, handler, pattern)
		return hn
	}
}
217
// addChild appends the new `child` node to the tree using `prefix` as the trie key.
// For a URL router like chi's, we split the static, param, regexp and wildcard segments
// into different nodes. In addition, addChild will recursively call itself until every
// pattern segment is added to the url pattern tree as individual nodes, depending on type.
//
// child's typ, prefix, rex and tail may be rewritten to describe only the first
// segment of prefix. Any remaining segments are attached below child
// recursively. child is then appended to n's children group for its final typ,
// and that group is re-sorted. The return value is the deepest node created,
// i.e. the node on which the caller should set the endpoint. An invalid regexp
// in a param panics.
func (n *node) addChild(child *node, prefix string) *node {
	search := prefix

	// handler leaf node added to the tree is the child.
	// this may be overridden later down the flow
	hn := child

	// Parse next segment
	segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search)

	// Add child depending on next up segment
	switch segTyp {

	case ntStatic:
		// Search prefix is all static (that is, has no params in path)
		// noop

	default:
		// Search prefix contains a param, regexp or wildcard

		if segTyp == ntRegexp {
			rex, err := regexp.Compile(segRexpat)
			if err != nil {
				panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat))
			}
			child.prefix = segRexpat
			child.rex = rex
		}

		if segStartIdx == 0 {
			// Route starts with a param
			child.typ = segTyp

			// segStartIdx is reused from here on as "where the text after this
			// segment begins": right after the param, or the end of search for
			// a catch-all (which must be the last segment).
			if segTyp == ntCatchAll {
				segStartIdx = -1
			} else {
				segStartIdx = segEndIdx
			}
			if segStartIdx < 0 {
				segStartIdx = len(search)
			}
			child.tail = segTail // for params, we set the tail

			if segStartIdx != len(search) {
				// add static edge for the remaining part, split the end.
				// its not possible to have adjacent param nodes, so its certainly
				// going to be a static node next.

				search = search[segStartIdx:] // advance search position

				nn := &node{
					typ:    ntStatic,
					label:  search[0],
					prefix: search,
				}
				hn = child.addChild(nn, search)
			}

		} else if segStartIdx > 0 {
			// Route has some param

			// starts with a static segment: child keeps only the static text
			// before the param and is demoted to a static node.
			child.typ = ntStatic
			child.prefix = search[:segStartIdx]
			child.rex = nil

			// add the param edge node
			search = search[segStartIdx:]

			nn := &node{
				typ:   segTyp,
				label: search[0],
				tail:  segTail,
			}
			hn = child.addChild(nn, search)

		}
	}

	n.children[child.typ] = append(n.children[child.typ], child)
	n.children[child.typ].Sort()
	return hn
}
305
306func (n *node) replaceChild(label, tail byte, child *node) {
307	for i := 0; i < len(n.children[child.typ]); i++ {
308		if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail {
309			n.children[child.typ][i] = child
310			n.children[child.typ][i].label = label
311			n.children[child.typ][i].tail = tail
312			return
313		}
314	}
315	panic("chi: replacing missing child")
316}
317
318func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node {
319	nds := n.children[ntyp]
320	for i := 0; i < len(nds); i++ {
321		if nds[i].label == label && nds[i].tail == tail {
322			if ntyp == ntRegexp && nds[i].prefix != prefix {
323				continue
324			}
325			return nds[i]
326		}
327	}
328	return nil
329}
330
331func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) {
332	// Set the handler for the method type on the node
333	if n.endpoints == nil {
334		n.endpoints = make(endpoints)
335	}
336
337	paramKeys := patParamKeys(pattern)
338
339	if method&mSTUB == mSTUB {
340		n.endpoints.Value(mSTUB).handler = handler
341	}
342	if method&mALL == mALL {
343		h := n.endpoints.Value(mALL)
344		h.handler = handler
345		h.pattern = pattern
346		h.paramKeys = paramKeys
347		for _, m := range methodMap {
348			h := n.endpoints.Value(m)
349			h.handler = handler
350			h.pattern = pattern
351			h.paramKeys = paramKeys
352		}
353	} else {
354		h := n.endpoints.Value(method)
355		h.handler = handler
356		h.pattern = pattern
357		h.paramKeys = paramKeys
358	}
359}
360
361func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) {
362	// Reset the context routing pattern and params
363	rctx.routePattern = ""
364	rctx.routeParams.Keys = rctx.routeParams.Keys[:0]
365	rctx.routeParams.Values = rctx.routeParams.Values[:0]
366
367	// Find the routing handlers for the path
368	rn := n.findRoute(rctx, method, path)
369	if rn == nil {
370		return nil, nil, nil
371	}
372
373	// Record the routing params in the request lifecycle
374	rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...)
375	rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...)
376
377	// Record the routing pattern in the request lifecycle
378	if rn.endpoints[method].pattern != "" {
379		rctx.routePattern = rn.endpoints[method].pattern
380		rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern)
381	}
382
383	return rn, rn.endpoints, rn.endpoints[method].handler
384}
385
// findRoute is the recursive worker behind FindRoute.
//
// Recursive edge traversal by checking all nodeTyp groups along the way.
// It's like searching through a multi-dimensional radix trie.
//
// Child groups are tried in nodeTyp order (static, regexp, param, catch-all).
// A matching child consumes its part of path and the search recurses into it.
// Param values are pushed onto rctx.routeParams.Values as they are matched
// and popped again when a branch fails. The param keys are appended to
// rctx.routeParams.Keys only once a node with a handler for method is found.
// A node that matches the whole path but has no handler for method sets
// rctx.methodNotAllowed. Returns the matched node, or nil.
func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node {
	nn := n
	search := path

	for t, nds := range nn.children {
		ntyp := nodeTyp(t)
		if len(nds) == 0 {
			continue
		}

		var xn *node
		xsearch := search

		var label byte
		if search != "" {
			label = search[0]
		}

		switch ntyp {
		case ntStatic:
			// Static children are keyed by their first byte; the whole prefix
			// must then be present at the start of the remaining path.
			xn = nds.findEdge(label)
			if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) {
				continue
			}
			xsearch = xsearch[len(xn.prefix):]

		case ntParam, ntRegexp:
			// short-circuit and return no matching route for empty param values
			if xsearch == "" {
				continue
			}

			// serially loop through each node grouped by the tail delimiter
			for idx := 0; idx < len(nds); idx++ {
				xn = nds[idx]

				// label for param nodes is the delimiter byte
				// (the node's tail is the byte that ends the param value)
				p := strings.IndexByte(xsearch, xn.tail)

				if p < 0 {
					if xn.tail == '/' {
						// A '/'-terminated param may also run to the end of the path.
						p = len(xsearch)
					} else {
						continue
					}
				}

				if ntyp == ntRegexp && xn.rex != nil {
					if !xn.rex.MatchString(xsearch[:p]) {
						continue
					}
				} else if strings.IndexByte(xsearch[:p], '/') != -1 {
					// avoid a match across path segments
					continue
				}

				prevlen := len(rctx.routeParams.Values)
				rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p])
				xsearch = xsearch[p:]

				if len(xsearch) == 0 {
					if xn.isLeaf() {
						h := xn.endpoints[method]
						if h != nil && h.handler != nil {
							rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
							return xn
						}

						// flag that the routing context found a route, but not a corresponding
						// supported method
						rctx.methodNotAllowed = true
					}
				}

				// recursively find the next node on this branch
				fin := xn.findRoute(rctx, method, xsearch)
				if fin != nil {
					return fin
				}

				// not found on this branch, reset vars
				rctx.routeParams.Values = rctx.routeParams.Values[:prevlen]
				xsearch = search
			}

			// NOTE(review): every candidate failed at this point. xn is the
			// last candidate and xsearch equals search again. An empty value is
			// pushed, and the shared code below searches xn's children with the
			// unconsumed path, popping the value again if that fails. This looks
			// like an intentional upstream fallback; confirm before changing.
			rctx.routeParams.Values = append(rctx.routeParams.Values, "")

		default:
			// catch-all nodes
			// The whole remaining path becomes the catch-all's value.
			rctx.routeParams.Values = append(rctx.routeParams.Values, search)
			xn = nds[0]
			xsearch = ""
		}

		if xn == nil {
			continue
		}

		// did we find it yet?
		if len(xsearch) == 0 {
			if xn.isLeaf() {
				h := xn.endpoints[method]
				if h != nil && h.handler != nil {
					rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
					return xn
				}

				// flag that the routing context found a route, but not a corresponding
				// supported method
				rctx.methodNotAllowed = true
			}
		}

		// recursively find the next node..
		fin := xn.findRoute(rctx, method, xsearch)
		if fin != nil {
			return fin
		}

		// Did not find final handler, let's remove the param here if it was set
		if xn.typ > ntStatic {
			if len(rctx.routeParams.Values) > 0 {
				rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1]
			}
		}

	}

	return nil
}
518
519func (n *node) findEdge(ntyp nodeTyp, label byte) *node {
520	nds := n.children[ntyp]
521	num := len(nds)
522	idx := 0
523
524	switch ntyp {
525	case ntStatic, ntParam, ntRegexp:
526		i, j := 0, num-1
527		for i <= j {
528			idx = i + (j-i)/2
529			if label > nds[idx].label {
530				i = idx + 1
531			} else if label < nds[idx].label {
532				j = idx - 1
533			} else {
534				i = num // breaks cond
535			}
536		}
537		if nds[idx].label != label {
538			return nil
539		}
540		return nds[idx]
541
542	default: // catch all
543		return nds[idx]
544	}
545}
546
547func (n *node) isLeaf() bool {
548	return n.endpoints != nil
549}
550
551func (n *node) findPattern(pattern string) bool {
552	nn := n
553	for _, nds := range nn.children {
554		if len(nds) == 0 {
555			continue
556		}
557
558		n = nn.findEdge(nds[0].typ, pattern[0])
559		if n == nil {
560			continue
561		}
562
563		var idx int
564		var xpattern string
565
566		switch n.typ {
567		case ntStatic:
568			idx = longestPrefix(pattern, n.prefix)
569			if idx < len(n.prefix) {
570				continue
571			}
572
573		case ntParam, ntRegexp:
574			idx = strings.IndexByte(pattern, '}') + 1
575
576		case ntCatchAll:
577			idx = longestPrefix(pattern, "*")
578
579		default:
580			panic("chi: unknown node type")
581		}
582
583		xpattern = pattern[idx:]
584		if len(xpattern) == 0 {
585			return true
586		}
587
588		return n.findPattern(xpattern)
589	}
590	return false
591}
592
593func (n *node) routes() []Route {
594	rts := []Route{}
595
596	n.walk(func(eps endpoints, subroutes Routes) bool {
597		if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil {
598			return false
599		}
600
601		// Group methodHandlers by unique patterns
602		pats := make(map[string]endpoints)
603
604		for mt, h := range eps {
605			if h.pattern == "" {
606				continue
607			}
608			p, ok := pats[h.pattern]
609			if !ok {
610				p = endpoints{}
611				pats[h.pattern] = p
612			}
613			p[mt] = h
614		}
615
616		for p, mh := range pats {
617			hs := make(map[string]http.Handler)
618			if mh[mALL] != nil && mh[mALL].handler != nil {
619				hs["*"] = mh[mALL].handler
620			}
621
622			for mt, h := range mh {
623				if h.handler == nil {
624					continue
625				}
626				m := methodTypString(mt)
627				if m == "" {
628					continue
629				}
630				hs[m] = h.handler
631			}
632
633			rt := Route{p, hs, subroutes}
634			rts = append(rts, rt)
635		}
636
637		return false
638	})
639
640	return rts
641}
642
643func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool {
644	// Visit the leaf values if any
645	if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) {
646		return true
647	}
648
649	// Recurse on the children
650	for _, ns := range n.children {
651		for _, cn := range ns {
652			if cn.walk(fn) {
653				return true
654			}
655		}
656	}
657	return false
658}
659
660// patNextSegment returns the next segment details from a pattern:
661// node type, param key, regexp string, param tail byte, param starting index, param ending index
662func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) {
663	ps := strings.Index(pattern, "{")
664	ws := strings.Index(pattern, "*")
665
666	if ps < 0 && ws < 0 {
667		return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing
668	}
669
670	// Sanity check
671	if ps >= 0 && ws >= 0 && ws < ps {
672		panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'")
673	}
674
675	var tail byte = '/' // Default endpoint tail to / byte
676
677	if ps >= 0 {
678		// Param/Regexp pattern is next
679		nt := ntParam
680
681		// Read to closing } taking into account opens and closes in curl count (cc)
682		cc := 0
683		pe := ps
684		for i, c := range pattern[ps:] {
685			if c == '{' {
686				cc++
687			} else if c == '}' {
688				cc--
689				if cc == 0 {
690					pe = ps + i
691					break
692				}
693			}
694		}
695		if pe == ps {
696			panic("chi: route param closing delimiter '}' is missing")
697		}
698
699		key := pattern[ps+1 : pe]
700		pe++ // set end to next position
701
702		if pe < len(pattern) {
703			tail = pattern[pe]
704		}
705
706		var rexpat string
707		if idx := strings.Index(key, ":"); idx >= 0 {
708			nt = ntRegexp
709			rexpat = key[idx+1:]
710			key = key[:idx]
711		}
712
713		if len(rexpat) > 0 {
714			if rexpat[0] != '^' {
715				rexpat = "^" + rexpat
716			}
717			if rexpat[len(rexpat)-1] != '$' {
718				rexpat += "$"
719			}
720		}
721
722		return nt, key, rexpat, tail, ps, pe
723	}
724
725	// Wildcard pattern as finale
726	if ws < len(pattern)-1 {
727		panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead")
728	}
729	return ntCatchAll, "*", "", 0, ws, len(pattern)
730}
731
732func patParamKeys(pattern string) []string {
733	pat := pattern
734	paramKeys := []string{}
735	for {
736		ptyp, paramKey, _, _, _, e := patNextSegment(pat)
737		if ptyp == ntStatic {
738			return paramKeys
739		}
740		for i := 0; i < len(paramKeys); i++ {
741			if paramKeys[i] == paramKey {
742				panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey))
743			}
744		}
745		paramKeys = append(paramKeys, paramKey)
746		pat = pat[e:]
747	}
748}
749
// longestPrefix returns how many leading bytes k1 and k2 have in common.
func longestPrefix(k1, k2 string) int {
	limit := len(k1)
	if len(k2) < limit {
		limit = len(k2)
	}
	n := 0
	for n < limit && k1[n] == k2[n] {
		n++
	}
	return n
}
765
766func methodTypString(method methodTyp) string {
767	for s, t := range methodMap {
768		if method == t {
769			return s
770		}
771	}
772	return ""
773}
774
775type nodes []*node
776
777// Sort the list of nodes by label
778func (ns nodes) Sort()              { sort.Sort(ns); ns.tailSort() }
779func (ns nodes) Len() int           { return len(ns) }
780func (ns nodes) Swap(i, j int)      { ns[i], ns[j] = ns[j], ns[i] }
781func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label }
782
783// tailSort pushes nodes with '/' as the tail to the end of the list for param nodes.
784// The list order determines the traversal order.
785func (ns nodes) tailSort() {
786	for i := len(ns) - 1; i >= 0; i-- {
787		if ns[i].typ > ntStatic && ns[i].tail == '/' {
788			ns.Swap(i, len(ns)-1)
789			return
790		}
791	}
792}
793
794func (ns nodes) findEdge(label byte) *node {
795	num := len(ns)
796	idx := 0
797	i, j := 0, num-1
798	for i <= j {
799		idx = i + (j-i)/2
800		if label > ns[idx].label {
801			i = idx + 1
802		} else if label < ns[idx].label {
803			j = idx - 1
804		} else {
805			i = num // breaks cond
806		}
807	}
808	if ns[idx].label != label {
809		return nil
810	}
811	return ns[idx]
812}
813
814// Route describes the details of a routing handler.
815// Handlers map key is an HTTP method
816type Route struct {
817	Pattern   string
818	Handlers  map[string]http.Handler
819	SubRoutes Routes
820}
821
822// WalkFunc is the type of the function called for each method and route visited by Walk.
823type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error
824
825// Walk walks any router tree that implements Routes interface.
826func Walk(r Routes, walkFn WalkFunc) error {
827	return walk(r, walkFn, "")
828}
829
830func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error {
831	for _, route := range r.Routes() {
832		mws := make([]func(http.Handler) http.Handler, len(parentMw))
833		copy(mws, parentMw)
834		mws = append(mws, r.Middlewares()...)
835
836		if route.SubRoutes != nil {
837			if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil {
838				return err
839			}
840			continue
841		}
842
843		for method, handler := range route.Handlers {
844			if method == "*" {
845				// Ignore a "catchAll" method, since we pass down all the specific methods for each route.
846				continue
847			}
848
849			fullRoute := parentRoute + route.Pattern
850			fullRoute = strings.Replace(fullRoute, "/*/", "/", -1)
851
852			if chain, ok := handler.(*ChainHandler); ok {
853				if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil {
854					return err
855				}
856			} else {
857				if err := walkFn(method, fullRoute, handler, mws...); err != nil {
858					return err
859				}
860			}
861		}
862	}
863
864	return nil
865}
866