1package chi 2 3// Radix tree implementation below is a based on the original work by 4// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go 5// (MIT licensed). It's been heavily modified for use as a HTTP routing tree. 6 7import ( 8 "fmt" 9 "math" 10 "net/http" 11 "regexp" 12 "sort" 13 "strconv" 14 "strings" 15) 16 17type methodTyp int 18 19const ( 20 mSTUB methodTyp = 1 << iota 21 mCONNECT 22 mDELETE 23 mGET 24 mHEAD 25 mOPTIONS 26 mPATCH 27 mPOST 28 mPUT 29 mTRACE 30) 31 32var mALL = mCONNECT | mDELETE | mGET | mHEAD | 33 mOPTIONS | mPATCH | mPOST | mPUT | mTRACE 34 35var methodMap = map[string]methodTyp{ 36 http.MethodConnect: mCONNECT, 37 http.MethodDelete: mDELETE, 38 http.MethodGet: mGET, 39 http.MethodHead: mHEAD, 40 http.MethodOptions: mOPTIONS, 41 http.MethodPatch: mPATCH, 42 http.MethodPost: mPOST, 43 http.MethodPut: mPUT, 44 http.MethodTrace: mTRACE, 45} 46 47// RegisterMethod adds support for custom HTTP method handlers, available 48// via Router#Method and Router#MethodFunc 49func RegisterMethod(method string) { 50 if method == "" { 51 return 52 } 53 method = strings.ToUpper(method) 54 if _, ok := methodMap[method]; ok { 55 return 56 } 57 n := len(methodMap) 58 if n > strconv.IntSize { 59 panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize)) 60 } 61 mt := methodTyp(math.Exp2(float64(n))) 62 methodMap[method] = mt 63 mALL |= mt 64} 65 66type nodeTyp uint8 67 68const ( 69 ntStatic nodeTyp = iota // /home 70 ntRegexp // /{id:[0-9]+} 71 ntParam // /{user} 72 ntCatchAll // /api/v1/* 73) 74 75type node struct { 76 // node type: static, regexp, param, catchAll 77 typ nodeTyp 78 79 // first byte of the prefix 80 label byte 81 82 // first byte of the child prefix 83 tail byte 84 85 // prefix is the common prefix we ignore 86 prefix string 87 88 // regexp matcher for regexp nodes 89 rex *regexp.Regexp 90 91 // HTTP handler endpoints on the leaf node 92 endpoints endpoints 93 94 // subroutes on the leaf node 95 subroutes Routes 96 97 // child nodes should be stored in-order for iteration, 98 // in groups of the node type. 99 children [ntCatchAll + 1]nodes 100} 101 102// endpoints is a mapping of http method constants to handlers 103// for a given route. 104type endpoints map[methodTyp]*endpoint 105 106type endpoint struct { 107 // endpoint handler 108 handler http.Handler 109 110 // pattern is the routing pattern for handler nodes 111 pattern string 112 113 // parameter keys recorded on handler nodes 114 paramKeys []string 115} 116 117func (s endpoints) Value(method methodTyp) *endpoint { 118 mh, ok := s[method] 119 if !ok { 120 mh = &endpoint{} 121 s[method] = mh 122 } 123 return mh 124} 125 126func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node { 127 var parent *node 128 search := pattern 129 130 for { 131 // Handle key exhaustion 132 if len(search) == 0 { 133 // Insert or update the node's leaf handler 134 n.setEndpoint(method, handler, pattern) 135 return n 136 } 137 138 // We're going to be searching for a wild node next, 139 // in this case, we need to get the tail 140 var label = search[0] 141 var segTail byte 142 var segEndIdx int 143 var segTyp nodeTyp 144 var segRexpat string 145 if label == '{' || label == '*' { 146 segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search) 147 } 148 149 var prefix string 150 if segTyp == ntRegexp { 151 prefix = segRexpat 152 } 153 154 // Look for the edge to attach to 155 parent = n 156 n = n.getEdge(segTyp, label, segTail, prefix) 157 158 // No edge, create one 159 if n == nil { 160 child := &node{label: label, tail: segTail, prefix: search} 161 hn := parent.addChild(child, search) 162 hn.setEndpoint(method, handler, pattern) 163 164 return hn 165 } 166 167 // Found an edge to match the pattern 168 169 if n.typ > ntStatic { 170 // We found a param node, trim the param from the search path and continue. 171 // This param/wild pattern segment would already be on the tree from a previous 172 // call to addChild when creating a new node. 173 search = search[segEndIdx:] 174 continue 175 } 176 177 // Static nodes fall below here. 178 // Determine longest prefix of the search key on match. 179 commonPrefix := longestPrefix(search, n.prefix) 180 if commonPrefix == len(n.prefix) { 181 // the common prefix is as long as the current node's prefix we're attempting to insert. 182 // keep the search going. 183 search = search[commonPrefix:] 184 continue 185 } 186 187 // Split the node 188 child := &node{ 189 typ: ntStatic, 190 prefix: search[:commonPrefix], 191 } 192 parent.replaceChild(search[0], segTail, child) 193 194 // Restore the existing node 195 n.label = n.prefix[commonPrefix] 196 n.prefix = n.prefix[commonPrefix:] 197 child.addChild(n, n.prefix) 198 199 // If the new key is a subset, set the method/handler on this node and finish. 200 search = search[commonPrefix:] 201 if len(search) == 0 { 202 child.setEndpoint(method, handler, pattern) 203 return child 204 } 205 206 // Create a new edge for the node 207 subchild := &node{ 208 typ: ntStatic, 209 label: search[0], 210 prefix: search, 211 } 212 hn := child.addChild(subchild, search) 213 hn.setEndpoint(method, handler, pattern) 214 return hn 215 } 216} 217 218// addChild appends the new `child` node to the tree using the `pattern` as the trie key. 219// For a URL router like chi's, we split the static, param, regexp and wildcard segments 220// into different nodes. In addition, addChild will recursively call itself until every 221// pattern segment is added to the url pattern tree as individual nodes, depending on type. 222func (n *node) addChild(child *node, prefix string) *node { 223 search := prefix 224 225 // handler leaf node added to the tree is the child. 226 // this may be overridden later down the flow 227 hn := child 228 229 // Parse next segment 230 segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search) 231 232 // Add child depending on next up segment 233 switch segTyp { 234 235 case ntStatic: 236 // Search prefix is all static (that is, has no params in path) 237 // noop 238 239 default: 240 // Search prefix contains a param, regexp or wildcard 241 242 if segTyp == ntRegexp { 243 rex, err := regexp.Compile(segRexpat) 244 if err != nil { 245 panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat)) 246 } 247 child.prefix = segRexpat 248 child.rex = rex 249 } 250 251 if segStartIdx == 0 { 252 // Route starts with a param 253 child.typ = segTyp 254 255 if segTyp == ntCatchAll { 256 segStartIdx = -1 257 } else { 258 segStartIdx = segEndIdx 259 } 260 if segStartIdx < 0 { 261 segStartIdx = len(search) 262 } 263 child.tail = segTail // for params, we set the tail 264 265 if segStartIdx != len(search) { 266 // add static edge for the remaining part, split the end. 267 // its not possible to have adjacent param nodes, so its certainly 268 // going to be a static node next. 269 270 search = search[segStartIdx:] // advance search position 271 272 nn := &node{ 273 typ: ntStatic, 274 label: search[0], 275 prefix: search, 276 } 277 hn = child.addChild(nn, search) 278 } 279 280 } else if segStartIdx > 0 { 281 // Route has some param 282 283 // starts with a static segment 284 child.typ = ntStatic 285 child.prefix = search[:segStartIdx] 286 child.rex = nil 287 288 // add the param edge node 289 search = search[segStartIdx:] 290 291 nn := &node{ 292 typ: segTyp, 293 label: search[0], 294 tail: segTail, 295 } 296 hn = child.addChild(nn, search) 297 298 } 299 } 300 301 n.children[child.typ] = append(n.children[child.typ], child) 302 n.children[child.typ].Sort() 303 return hn 304} 305 306func (n *node) replaceChild(label, tail byte, child *node) { 307 for i := 0; i < len(n.children[child.typ]); i++ { 308 if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail { 309 n.children[child.typ][i] = child 310 n.children[child.typ][i].label = label 311 n.children[child.typ][i].tail = tail 312 return 313 } 314 } 315 panic("chi: replacing missing child") 316} 317 318func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node { 319 nds := n.children[ntyp] 320 for i := 0; i < len(nds); i++ { 321 if nds[i].label == label && nds[i].tail == tail { 322 if ntyp == ntRegexp && nds[i].prefix != prefix { 323 continue 324 } 325 return nds[i] 326 } 327 } 328 return nil 329} 330 331func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) { 332 // Set the handler for the method type on the node 333 if n.endpoints == nil { 334 n.endpoints = make(endpoints) 335 } 336 337 paramKeys := patParamKeys(pattern) 338 339 if method&mSTUB == mSTUB { 340 n.endpoints.Value(mSTUB).handler = handler 341 } 342 if method&mALL == mALL { 343 h := n.endpoints.Value(mALL) 344 h.handler = handler 345 h.pattern = pattern 346 h.paramKeys = paramKeys 347 for _, m := range methodMap { 348 h := n.endpoints.Value(m) 349 h.handler = handler 350 h.pattern = pattern 351 h.paramKeys = paramKeys 352 } 353 } else { 354 h := n.endpoints.Value(method) 355 h.handler = handler 356 h.pattern = pattern 357 h.paramKeys = paramKeys 358 } 359} 360 361func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) { 362 // Reset the context routing pattern and params 363 rctx.routePattern = "" 364 rctx.routeParams.Keys = rctx.routeParams.Keys[:0] 365 rctx.routeParams.Values = rctx.routeParams.Values[:0] 366 367 // Find the routing handlers for the path 368 rn := n.findRoute(rctx, method, path) 369 if rn == nil { 370 return nil, nil, nil 371 } 372 373 // Record the routing params in the request lifecycle 374 rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...) 375 rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...) 376 377 // Record the routing pattern in the request lifecycle 378 if rn.endpoints[method].pattern != "" { 379 rctx.routePattern = rn.endpoints[method].pattern 380 rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern) 381 } 382 383 return rn, rn.endpoints, rn.endpoints[method].handler 384} 385 386// Recursive edge traversal by checking all nodeTyp groups along the way. 387// It's like searching through a multi-dimensional radix trie. 388func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { 389 nn := n 390 search := path 391 392 for t, nds := range nn.children { 393 ntyp := nodeTyp(t) 394 if len(nds) == 0 { 395 continue 396 } 397 398 var xn *node 399 xsearch := search 400 401 var label byte 402 if search != "" { 403 label = search[0] 404 } 405 406 switch ntyp { 407 case ntStatic: 408 xn = nds.findEdge(label) 409 if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) { 410 continue 411 } 412 xsearch = xsearch[len(xn.prefix):] 413 414 case ntParam, ntRegexp: 415 // short-circuit and return no matching route for empty param values 416 if xsearch == "" { 417 continue 418 } 419 420 // serially loop through each node grouped by the tail delimiter 421 for idx := 0; idx < len(nds); idx++ { 422 xn = nds[idx] 423 424 // label for param nodes is the delimiter byte 425 p := strings.IndexByte(xsearch, xn.tail) 426 427 if p < 0 { 428 if xn.tail == '/' { 429 p = len(xsearch) 430 } else { 431 continue 432 } 433 } 434 435 if ntyp == ntRegexp && xn.rex != nil { 436 if !xn.rex.MatchString(xsearch[:p]) { 437 continue 438 } 439 } else if strings.IndexByte(xsearch[:p], '/') != -1 { 440 // avoid a match across path segments 441 continue 442 } 443 444 prevlen := len(rctx.routeParams.Values) 445 rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p]) 446 xsearch = xsearch[p:] 447 448 if len(xsearch) == 0 { 449 if xn.isLeaf() { 450 h := xn.endpoints[method] 451 if h != nil && h.handler != nil { 452 rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) 453 return xn 454 } 455 456 // flag that the routing context found a route, but not a corresponding 457 // supported method 458 rctx.methodNotAllowed = true 459 } 460 } 461 462 // recursively find the next node on this branch 463 fin := xn.findRoute(rctx, method, xsearch) 464 if fin != nil { 465 return fin 466 } 467 468 // not found on this branch, reset vars 469 rctx.routeParams.Values = rctx.routeParams.Values[:prevlen] 470 xsearch = search 471 } 472 473 rctx.routeParams.Values = append(rctx.routeParams.Values, "") 474 475 default: 476 // catch-all nodes 477 rctx.routeParams.Values = append(rctx.routeParams.Values, search) 478 xn = nds[0] 479 xsearch = "" 480 } 481 482 if xn == nil { 483 continue 484 } 485 486 // did we find it yet? 487 if len(xsearch) == 0 { 488 if xn.isLeaf() { 489 h := xn.endpoints[method] 490 if h != nil && h.handler != nil { 491 rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) 492 return xn 493 } 494 495 // flag that the routing context found a route, but not a corresponding 496 // supported method 497 rctx.methodNotAllowed = true 498 } 499 } 500 501 // recursively find the next node.. 502 fin := xn.findRoute(rctx, method, xsearch) 503 if fin != nil { 504 return fin 505 } 506 507 // Did not find final handler, let's remove the param here if it was set 508 if xn.typ > ntStatic { 509 if len(rctx.routeParams.Values) > 0 { 510 rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1] 511 } 512 } 513 514 } 515 516 return nil 517} 518 519func (n *node) findEdge(ntyp nodeTyp, label byte) *node { 520 nds := n.children[ntyp] 521 num := len(nds) 522 idx := 0 523 524 switch ntyp { 525 case ntStatic, ntParam, ntRegexp: 526 i, j := 0, num-1 527 for i <= j { 528 idx = i + (j-i)/2 529 if label > nds[idx].label { 530 i = idx + 1 531 } else if label < nds[idx].label { 532 j = idx - 1 533 } else { 534 i = num // breaks cond 535 } 536 } 537 if nds[idx].label != label { 538 return nil 539 } 540 return nds[idx] 541 542 default: // catch all 543 return nds[idx] 544 } 545} 546 547func (n *node) isLeaf() bool { 548 return n.endpoints != nil 549} 550 551func (n *node) findPattern(pattern string) bool { 552 nn := n 553 for _, nds := range nn.children { 554 if len(nds) == 0 { 555 continue 556 } 557 558 n = nn.findEdge(nds[0].typ, pattern[0]) 559 if n == nil { 560 continue 561 } 562 563 var idx int 564 var xpattern string 565 566 switch n.typ { 567 case ntStatic: 568 idx = longestPrefix(pattern, n.prefix) 569 if idx < len(n.prefix) { 570 continue 571 } 572 573 case ntParam, ntRegexp: 574 idx = strings.IndexByte(pattern, '}') + 1 575 576 case ntCatchAll: 577 idx = longestPrefix(pattern, "*") 578 579 default: 580 panic("chi: unknown node type") 581 } 582 583 xpattern = pattern[idx:] 584 if len(xpattern) == 0 { 585 return true 586 } 587 588 return n.findPattern(xpattern) 589 } 590 return false 591} 592 593func (n *node) routes() []Route { 594 rts := []Route{} 595 596 n.walk(func(eps endpoints, subroutes Routes) bool { 597 if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil { 598 return false 599 } 600 601 // Group methodHandlers by unique patterns 602 pats := make(map[string]endpoints) 603 604 for mt, h := range eps { 605 if h.pattern == "" { 606 continue 607 } 608 p, ok := pats[h.pattern] 609 if !ok { 610 p = endpoints{} 611 pats[h.pattern] = p 612 } 613 p[mt] = h 614 } 615 616 for p, mh := range pats { 617 hs := make(map[string]http.Handler) 618 if mh[mALL] != nil && mh[mALL].handler != nil { 619 hs["*"] = mh[mALL].handler 620 } 621 622 for mt, h := range mh { 623 if h.handler == nil { 624 continue 625 } 626 m := methodTypString(mt) 627 if m == "" { 628 continue 629 } 630 hs[m] = h.handler 631 } 632 633 rt := Route{p, hs, subroutes} 634 rts = append(rts, rt) 635 } 636 637 return false 638 }) 639 640 return rts 641} 642 643func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool { 644 // Visit the leaf values if any 645 if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) { 646 return true 647 } 648 649 // Recurse on the children 650 for _, ns := range n.children { 651 for _, cn := range ns { 652 if cn.walk(fn) { 653 return true 654 } 655 } 656 } 657 return false 658} 659 660// patNextSegment returns the next segment details from a pattern: 661// node type, param key, regexp string, param tail byte, param starting index, param ending index 662func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) { 663 ps := strings.Index(pattern, "{") 664 ws := strings.Index(pattern, "*") 665 666 if ps < 0 && ws < 0 { 667 return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing 668 } 669 670 // Sanity check 671 if ps >= 0 && ws >= 0 && ws < ps { 672 panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'") 673 } 674 675 var tail byte = '/' // Default endpoint tail to / byte 676 677 if ps >= 0 { 678 // Param/Regexp pattern is next 679 nt := ntParam 680 681 // Read to closing } taking into account opens and closes in curl count (cc) 682 cc := 0 683 pe := ps 684 for i, c := range pattern[ps:] { 685 if c == '{' { 686 cc++ 687 } else if c == '}' { 688 cc-- 689 if cc == 0 { 690 pe = ps + i 691 break 692 } 693 } 694 } 695 if pe == ps { 696 panic("chi: route param closing delimiter '}' is missing") 697 } 698 699 key := pattern[ps+1 : pe] 700 pe++ // set end to next position 701 702 if pe < len(pattern) { 703 tail = pattern[pe] 704 } 705 706 var rexpat string 707 if idx := strings.Index(key, ":"); idx >= 0 { 708 nt = ntRegexp 709 rexpat = key[idx+1:] 710 key = key[:idx] 711 } 712 713 if len(rexpat) > 0 { 714 if rexpat[0] != '^' { 715 rexpat = "^" + rexpat 716 } 717 if rexpat[len(rexpat)-1] != '$' { 718 rexpat += "$" 719 } 720 } 721 722 return nt, key, rexpat, tail, ps, pe 723 } 724 725 // Wildcard pattern as finale 726 if ws < len(pattern)-1 { 727 panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead") 728 } 729 return ntCatchAll, "*", "", 0, ws, len(pattern) 730} 731 732func patParamKeys(pattern string) []string { 733 pat := pattern 734 paramKeys := []string{} 735 for { 736 ptyp, paramKey, _, _, _, e := patNextSegment(pat) 737 if ptyp == ntStatic { 738 return paramKeys 739 } 740 for i := 0; i < len(paramKeys); i++ { 741 if paramKeys[i] == paramKey { 742 panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey)) 743 } 744 } 745 paramKeys = append(paramKeys, paramKey) 746 pat = pat[e:] 747 } 748} 749 750// longestPrefix finds the length of the shared prefix 751// of two strings 752func longestPrefix(k1, k2 string) int { 753 max := len(k1) 754 if l := len(k2); l < max { 755 max = l 756 } 757 var i int 758 for i = 0; i < max; i++ { 759 if k1[i] != k2[i] { 760 break 761 } 762 } 763 return i 764} 765 766func methodTypString(method methodTyp) string { 767 for s, t := range methodMap { 768 if method == t { 769 return s 770 } 771 } 772 return "" 773} 774 775type nodes []*node 776 777// Sort the list of nodes by label 778func (ns nodes) Sort() { sort.Sort(ns); ns.tailSort() } 779func (ns nodes) Len() int { return len(ns) } 780func (ns nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] } 781func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label } 782 783// tailSort pushes nodes with '/' as the tail to the end of the list for param nodes. 784// The list order determines the traversal order. 785func (ns nodes) tailSort() { 786 for i := len(ns) - 1; i >= 0; i-- { 787 if ns[i].typ > ntStatic && ns[i].tail == '/' { 788 ns.Swap(i, len(ns)-1) 789 return 790 } 791 } 792} 793 794func (ns nodes) findEdge(label byte) *node { 795 num := len(ns) 796 idx := 0 797 i, j := 0, num-1 798 for i <= j { 799 idx = i + (j-i)/2 800 if label > ns[idx].label { 801 i = idx + 1 802 } else if label < ns[idx].label { 803 j = idx - 1 804 } else { 805 i = num // breaks cond 806 } 807 } 808 if ns[idx].label != label { 809 return nil 810 } 811 return ns[idx] 812} 813 814// Route describes the details of a routing handler. 815// Handlers map key is an HTTP method 816type Route struct { 817 Pattern string 818 Handlers map[string]http.Handler 819 SubRoutes Routes 820} 821 822// WalkFunc is the type of the function called for each method and route visited by Walk. 823type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error 824 825// Walk walks any router tree that implements Routes interface. 826func Walk(r Routes, walkFn WalkFunc) error { 827 return walk(r, walkFn, "") 828} 829 830func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error { 831 for _, route := range r.Routes() { 832 mws := make([]func(http.Handler) http.Handler, len(parentMw)) 833 copy(mws, parentMw) 834 mws = append(mws, r.Middlewares()...) 835 836 if route.SubRoutes != nil { 837 if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil { 838 return err 839 } 840 continue 841 } 842 843 for method, handler := range route.Handlers { 844 if method == "*" { 845 // Ignore a "catchAll" method, since we pass down all the specific methods for each route. 846 continue 847 } 848 849 fullRoute := parentRoute + route.Pattern 850 fullRoute = strings.Replace(fullRoute, "/*/", "/", -1) 851 852 if chain, ok := handler.(*ChainHandler); ok { 853 if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil { 854 return err 855 } 856 } else { 857 if err := walkFn(method, fullRoute, handler, mws...); err != nil { 858 return err 859 } 860 } 861 } 862 } 863 864 return nil 865} 866