package chi

// Radix tree implementation below is based on the original work by
// Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go
// (MIT licensed). It's been heavily modified for use as a HTTP routing tree.

import (
	"fmt"
	"net/http"
	"regexp"
	"sort"
	"strconv"
	"strings"
)

// methodTyp is a bit flag identifying an HTTP method. Each method owns one
// bit, so several methods can be OR-ed together into one value (see mALL).
type methodTyp uint

const (
	// mSTUB occupies bit 0; the real HTTP methods follow on bits 1..9.
	mSTUB methodTyp = 1 << iota
	mCONNECT
	mDELETE
	mGET
	mHEAD
	mOPTIONS
	mPATCH
	mPOST
	mPUT
	mTRACE
)

// mALL is the union of every registered method flag (mSTUB excluded).
// RegisterMethod widens it when a custom method is added.
var mALL = mCONNECT | mDELETE | mGET | mHEAD |
	mOPTIONS | mPATCH | mPOST | mPUT | mTRACE

// methodMap translates an HTTP method name to its methodTyp flag.
var methodMap = map[string]methodTyp{
	http.MethodConnect: mCONNECT,
	http.MethodDelete:  mDELETE,
	http.MethodGet:     mGET,
	http.MethodHead:    mHEAD,
	http.MethodOptions: mOPTIONS,
	http.MethodPatch:   mPATCH,
	http.MethodPost:    mPOST,
	http.MethodPut:     mPUT,
	http.MethodTrace:   mTRACE,
}

// RegisterMethod adds support for custom HTTP method handlers, available
// via Router#Method and Router#MethodFunc.
//
// The method name is upper-cased before it is stored. An empty name or a
// name that is already registered is ignored. The function panics once no
// free bit is left in methodTyp. It writes to the package-level methodMap
// and mALL without any locking.
func RegisterMethod(method string) {
	if method == "" {
		return
	}
	method = strings.ToUpper(method)
	if _, ok := methodMap[method]; ok {
		return
	}
	n := len(methodMap)
	if n > strconv.IntSize-2 {
		panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize))
	}
	// 2 << n == 1 << (n+1): bit 0 is mSTUB and the n known methods use
	// bits 1..n, so n+1 is the next free bit.
	mt := methodTyp(2 << n)
	methodMap[method] = mt
	mALL |= mt
}

// nodeTyp classifies a tree node by the kind of pattern segment it holds.
// The numeric order is also the index into node.children.
type nodeTyp uint8

const (
	ntStatic   nodeTyp = iota // /home
	ntRegexp                  // /{id:[0-9]+}
	ntParam                   // /{user}
	ntCatchAll                // /api/v1/*
)

// node is a single vertex of the routing tree.
type node struct {
	// subroutes on the leaf node
	subroutes Routes

	// regexp matcher for regexp nodes
	rex *regexp.Regexp

	// HTTP handler endpoints on the leaf node
	endpoints endpoints

	// prefix is the common prefix we ignore
	prefix string

	// child nodes should be stored in-order for iteration,
	// in groups of the node type.
	children [ntCatchAll + 1]nodes

	// first byte of the child prefix
	tail byte

	// node type: static, regexp, param, catchAll
	typ nodeTyp

	// first byte of the prefix
	label byte
}

// endpoints is a mapping of http method constants to handlers
// for a given route.
type endpoints map[methodTyp]*endpoint

// endpoint is what a leaf node stores for one HTTP method.
type endpoint struct {
	// endpoint handler
	handler http.Handler

	// pattern is the routing pattern for handler nodes
	pattern string

	// parameter keys recorded on handler nodes
	paramKeys []string
}

// Value returns the endpoint stored for method, creating and storing an
// empty one first when the map has no entry yet. The receiver must be a
// non-nil map because a missing entry is written back into it.
func (s endpoints) Value(method methodTyp) *endpoint {
	mh, ok := s[method]
	if !ok {
		mh = &endpoint{}
		s[method] = mh
	}
	return mh
}

// InsertRoute adds pattern to the tree rooted at n and attaches handler to
// it for the given method flag(s). It returns the node that now carries the
// endpoint. Existing static nodes are split when the new pattern shares only
// part of their prefix. Malformed patterns panic inside patNextSegment,
// addChild or setEndpoint.
func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node {
	var parent *node
	search := pattern

	for {
		// Handle key exhaustion
		if len(search) == 0 {
			// Insert or update the node's leaf handler
			n.setEndpoint(method, handler, pattern)
			return n
		}

		// We're going to be searching for a wild node next,
		// in this case, we need to get the tail
		var label = search[0]
		var segTail byte
		var segEndIdx int
		var segTyp nodeTyp
		var segRexpat string
		if label == '{' || label == '*' {
			segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search)
		}

		// Regexp edges are additionally keyed by their expression, so two
		// different expressions at the same position get separate nodes.
		var prefix string
		if segTyp == ntRegexp {
			prefix = segRexpat
		}

		// Look for the edge to attach to
		parent = n
		n = n.getEdge(segTyp, label, segTail, prefix)

		// No edge, create one
		if n == nil {
			child := &node{label: label, tail: segTail, prefix: search}
			hn := parent.addChild(child, search)
			hn.setEndpoint(method, handler, pattern)

			return hn
		}

		// Found an edge to match the pattern

		if n.typ > ntStatic {
			// We found a param node, trim the param from the search path and continue.
			// This param/wild pattern segment would already be on the tree from a previous
			// call to addChild when creating a new node.
			search = search[segEndIdx:]
			continue
		}

		// Static nodes fall below here.
		// Determine longest prefix of the search key on match.
		commonPrefix := longestPrefix(search, n.prefix)
		if commonPrefix == len(n.prefix) {
			// the common prefix is as long as the current node's prefix we're attempting to insert.
			// keep the search going.
			search = search[commonPrefix:]
			continue
		}

		// Split the node: a new static node holding only the shared prefix
		// takes the place of n under parent.
		child := &node{
			typ:    ntStatic,
			prefix: search[:commonPrefix],
		}
		parent.replaceChild(search[0], segTail, child)

		// Restore the existing node as a child of the split node, keyed by
		// what is left of its prefix.
		n.label = n.prefix[commonPrefix]
		n.prefix = n.prefix[commonPrefix:]
		child.addChild(n, n.prefix)

		// If the new key is a subset, set the method/handler on this node and finish.
		search = search[commonPrefix:]
		if len(search) == 0 {
			child.setEndpoint(method, handler, pattern)
			return child
		}

		// Create a new edge for the node
		subchild := &node{
			typ:    ntStatic,
			label:  search[0],
			prefix: search,
		}
		hn := child.addChild(subchild, search)
		hn.setEndpoint(method, handler, pattern)
		return hn
	}
}

// addChild appends the new `child` node to the tree using the `pattern` as the trie key.
// For a URL router like chi's, we split the static, param, regexp and wildcard segments
// into different nodes. In addition, addChild will recursively call itself until every
// pattern segment is added to the url pattern tree as individual nodes, depending on type.
//
// The returned node is the last one created for the pattern, i.e. the node
// that should receive the endpoint. It panics when a regexp segment does
// not compile.
func (n *node) addChild(child *node, prefix string) *node {
	search := prefix

	// handler leaf node added to the tree is the child.
	// this may be overridden later down the flow
	hn := child

	// Parse next segment
	segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search)

	// Add child depending on next up segment
	switch segTyp {

	case ntStatic:
		// Search prefix is all static (that is, has no params in path)
		// noop

	default:
		// Search prefix contains a param, regexp or wildcard

		if segTyp == ntRegexp {
			rex, err := regexp.Compile(segRexpat)
			if err != nil {
				panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat))
			}
			child.prefix = segRexpat
			child.rex = rex
		}

		if segStartIdx == 0 {
			// Route starts with a param
			child.typ = segTyp

			// Work out where the remainder after this segment begins; a
			// catch-all always consumes the rest of the pattern.
			if segTyp == ntCatchAll {
				segStartIdx = -1
			} else {
				segStartIdx = segEndIdx
			}
			if segStartIdx < 0 {
				segStartIdx = len(search)
			}
			child.tail = segTail // for params, we set the tail

			if segStartIdx != len(search) {
				// add static edge for the remaining part, split the end.
				// its not possible to have adjacent param nodes, so its certainly
				// going to be a static node next.

				search = search[segStartIdx:] // advance search position

				nn := &node{
					typ:    ntStatic,
					label:  search[0],
					prefix: search,
				}
				hn = child.addChild(nn, search)
			}

		} else if segStartIdx > 0 {
			// Route has some param

			// starts with a static segment; this also undoes the regexp
			// fields that may have been set on child above.
			child.typ = ntStatic
			child.prefix = search[:segStartIdx]
			child.rex = nil

			// add the param edge node
			search = search[segStartIdx:]

			nn := &node{
				typ:   segTyp,
				label: search[0],
				tail:  segTail,
			}
			hn = child.addChild(nn, search)

		}
	}

	// Attach child to its type group and restore the group's ordering.
	n.children[child.typ] = append(n.children[child.typ], child)
	n.children[child.typ].Sort()
	return hn
}

// replaceChild swaps the child of n that has the given label and tail (and
// the same node type as child) for child, copying label and tail onto the
// replacement. It panics when no such child exists.
func (n *node) replaceChild(label, tail byte, child *node) {
	for i := 0; i < len(n.children[child.typ]); i++ {
		if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail {
			n.children[child.typ][i] = child
			n.children[child.typ][i].label = label
			n.children[child.typ][i].tail = tail
			return
		}
	}
	panic("chi: replacing missing child")
}

// getEdge returns the child of n in the ntyp group whose label and tail
// match, or nil when there is none. For regexp nodes the stored prefix (the
// expression) has to equal prefix as well.
func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node {
	nds := n.children[ntyp]
	for i := 0; i < len(nds); i++ {
		if nds[i].label == label && nds[i].tail == tail {
			if ntyp == ntRegexp && nds[i].prefix != prefix {
				continue
			}
			return nds[i]
		}
	}
	return nil
}

// setEndpoint records handler and pattern on n for the given method flag(s),
// allocating the endpoints map on first use. The param keys of pattern are
// extracted through patParamKeys, which panics on duplicate keys.
func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) {
	// Set the handler for the method type on the node
	if n.endpoints == nil {
		n.endpoints = make(endpoints)
	}

	paramKeys := patParamKeys(pattern)

	if method&mSTUB == mSTUB {
		n.endpoints.Value(mSTUB).handler = handler
	}
	if method&mALL == mALL {
		// All methods requested: fill the mALL entry and one entry per
		// individual method known to methodMap.
		h := n.endpoints.Value(mALL)
		h.handler = handler
		h.pattern = pattern
		h.paramKeys = paramKeys
		for _, m := range methodMap {
			h := n.endpoints.Value(m)
			h.handler = handler
			h.pattern = pattern
			h.paramKeys = paramKeys
		}
	} else {
		h := n.endpoints.Value(method)
		h.handler = handler
		h.pattern = pattern
		h.paramKeys = paramKeys
	}
}

// FindRoute looks up path for method in the tree rooted at n. On a match it
// appends the captured params to rctx.URLParams, records the routing pattern
// on rctx (when the endpoint has one) and returns the matched node, that
// node's endpoints and the handler for method. Without a match all three
// results are nil.
func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) {
	// Reset the context routing pattern and params
	rctx.routePattern = ""
	rctx.routeParams.Keys = rctx.routeParams.Keys[:0]
	rctx.routeParams.Values = rctx.routeParams.Values[:0]

	// Find the routing handlers for the path
	rn := n.findRoute(rctx, method, path)
	if rn == nil {
		return nil, nil, nil
	}

	// Record the routing params in the request lifecycle
	rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...)
	rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...)

	// Record the routing pattern in the request lifecycle
	if rn.endpoints[method].pattern != "" {
		rctx.routePattern = rn.endpoints[method].pattern
		rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern)
	}

	return rn, rn.endpoints, rn.endpoints[method].handler
}

// Recursive edge traversal by checking all nodeTyp groups along the way.
// It's like searching through a multi-dimensional radix trie.
387func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { 388 nn := n 389 search := path 390 391 for t, nds := range nn.children { 392 ntyp := nodeTyp(t) 393 if len(nds) == 0 { 394 continue 395 } 396 397 var xn *node 398 xsearch := search 399 400 var label byte 401 if search != "" { 402 label = search[0] 403 } 404 405 switch ntyp { 406 case ntStatic: 407 xn = nds.findEdge(label) 408 if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) { 409 continue 410 } 411 xsearch = xsearch[len(xn.prefix):] 412 413 case ntParam, ntRegexp: 414 // short-circuit and return no matching route for empty param values 415 if xsearch == "" { 416 continue 417 } 418 419 // serially loop through each node grouped by the tail delimiter 420 for idx := 0; idx < len(nds); idx++ { 421 xn = nds[idx] 422 423 // label for param nodes is the delimiter byte 424 p := strings.IndexByte(xsearch, xn.tail) 425 426 if p < 0 { 427 if xn.tail == '/' { 428 p = len(xsearch) 429 } else { 430 continue 431 } 432 } else if ntyp == ntRegexp && p == 0 { 433 continue 434 } 435 436 if ntyp == ntRegexp && xn.rex != nil { 437 if !xn.rex.MatchString(xsearch[:p]) { 438 continue 439 } 440 } else if strings.IndexByte(xsearch[:p], '/') != -1 { 441 // avoid a match across path segments 442 continue 443 } 444 445 prevlen := len(rctx.routeParams.Values) 446 rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p]) 447 xsearch = xsearch[p:] 448 449 if len(xsearch) == 0 { 450 if xn.isLeaf() { 451 h := xn.endpoints[method] 452 if h != nil && h.handler != nil { 453 rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) 
454 return xn 455 } 456 457 // flag that the routing context found a route, but not a corresponding 458 // supported method 459 rctx.methodNotAllowed = true 460 } 461 } 462 463 // recursively find the next node on this branch 464 fin := xn.findRoute(rctx, method, xsearch) 465 if fin != nil { 466 return fin 467 } 468 469 // not found on this branch, reset vars 470 rctx.routeParams.Values = rctx.routeParams.Values[:prevlen] 471 xsearch = search 472 } 473 474 rctx.routeParams.Values = append(rctx.routeParams.Values, "") 475 476 default: 477 // catch-all nodes 478 rctx.routeParams.Values = append(rctx.routeParams.Values, search) 479 xn = nds[0] 480 xsearch = "" 481 } 482 483 if xn == nil { 484 continue 485 } 486 487 // did we find it yet? 488 if len(xsearch) == 0 { 489 if xn.isLeaf() { 490 h := xn.endpoints[method] 491 if h != nil && h.handler != nil { 492 rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) 493 return xn 494 } 495 496 // flag that the routing context found a route, but not a corresponding 497 // supported method 498 rctx.methodNotAllowed = true 499 } 500 } 501 502 // recursively find the next node.. 
503 fin := xn.findRoute(rctx, method, xsearch) 504 if fin != nil { 505 return fin 506 } 507 508 // Did not find final handler, let's remove the param here if it was set 509 if xn.typ > ntStatic { 510 if len(rctx.routeParams.Values) > 0 { 511 rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1] 512 } 513 } 514 515 } 516 517 return nil 518} 519 520func (n *node) findEdge(ntyp nodeTyp, label byte) *node { 521 nds := n.children[ntyp] 522 num := len(nds) 523 idx := 0 524 525 switch ntyp { 526 case ntStatic, ntParam, ntRegexp: 527 i, j := 0, num-1 528 for i <= j { 529 idx = i + (j-i)/2 530 if label > nds[idx].label { 531 i = idx + 1 532 } else if label < nds[idx].label { 533 j = idx - 1 534 } else { 535 i = num // breaks cond 536 } 537 } 538 if nds[idx].label != label { 539 return nil 540 } 541 return nds[idx] 542 543 default: // catch all 544 return nds[idx] 545 } 546} 547 548func (n *node) isLeaf() bool { 549 return n.endpoints != nil 550} 551 552func (n *node) findPattern(pattern string) bool { 553 nn := n 554 for _, nds := range nn.children { 555 if len(nds) == 0 { 556 continue 557 } 558 559 n = nn.findEdge(nds[0].typ, pattern[0]) 560 if n == nil { 561 continue 562 } 563 564 var idx int 565 var xpattern string 566 567 switch n.typ { 568 case ntStatic: 569 idx = longestPrefix(pattern, n.prefix) 570 if idx < len(n.prefix) { 571 continue 572 } 573 574 case ntParam, ntRegexp: 575 idx = strings.IndexByte(pattern, '}') + 1 576 577 case ntCatchAll: 578 idx = longestPrefix(pattern, "*") 579 580 default: 581 panic("chi: unknown node type") 582 } 583 584 xpattern = pattern[idx:] 585 if len(xpattern) == 0 { 586 return true 587 } 588 589 return n.findPattern(xpattern) 590 } 591 return false 592} 593 594func (n *node) routes() []Route { 595 rts := []Route{} 596 597 n.walk(func(eps endpoints, subroutes Routes) bool { 598 if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil { 599 return false 600 } 601 602 // Group methodHandlers by 
unique patterns 603 pats := make(map[string]endpoints) 604 605 for mt, h := range eps { 606 if h.pattern == "" { 607 continue 608 } 609 p, ok := pats[h.pattern] 610 if !ok { 611 p = endpoints{} 612 pats[h.pattern] = p 613 } 614 p[mt] = h 615 } 616 617 for p, mh := range pats { 618 hs := make(map[string]http.Handler) 619 if mh[mALL] != nil && mh[mALL].handler != nil { 620 hs["*"] = mh[mALL].handler 621 } 622 623 for mt, h := range mh { 624 if h.handler == nil { 625 continue 626 } 627 m := methodTypString(mt) 628 if m == "" { 629 continue 630 } 631 hs[m] = h.handler 632 } 633 634 rt := Route{subroutes, hs, p} 635 rts = append(rts, rt) 636 } 637 638 return false 639 }) 640 641 return rts 642} 643 644func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool { 645 // Visit the leaf values if any 646 if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) { 647 return true 648 } 649 650 // Recurse on the children 651 for _, ns := range n.children { 652 for _, cn := range ns { 653 if cn.walk(fn) { 654 return true 655 } 656 } 657 } 658 return false 659} 660 661// patNextSegment returns the next segment details from a pattern: 662// node type, param key, regexp string, param tail byte, param starting index, param ending index 663func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) { 664 ps := strings.Index(pattern, "{") 665 ws := strings.Index(pattern, "*") 666 667 if ps < 0 && ws < 0 { 668 return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing 669 } 670 671 // Sanity check 672 if ps >= 0 && ws >= 0 && ws < ps { 673 panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'") 674 } 675 676 var tail byte = '/' // Default endpoint tail to / byte 677 678 if ps >= 0 { 679 // Param/Regexp pattern is next 680 nt := ntParam 681 682 // Read to closing } taking into account opens and closes in curl count (cc) 683 cc := 0 684 pe := ps 685 for i, c := range pattern[ps:] { 686 
if c == '{' { 687 cc++ 688 } else if c == '}' { 689 cc-- 690 if cc == 0 { 691 pe = ps + i 692 break 693 } 694 } 695 } 696 if pe == ps { 697 panic("chi: route param closing delimiter '}' is missing") 698 } 699 700 key := pattern[ps+1 : pe] 701 pe++ // set end to next position 702 703 if pe < len(pattern) { 704 tail = pattern[pe] 705 } 706 707 var rexpat string 708 if idx := strings.Index(key, ":"); idx >= 0 { 709 nt = ntRegexp 710 rexpat = key[idx+1:] 711 key = key[:idx] 712 } 713 714 if len(rexpat) > 0 { 715 if rexpat[0] != '^' { 716 rexpat = "^" + rexpat 717 } 718 if rexpat[len(rexpat)-1] != '$' { 719 rexpat += "$" 720 } 721 } 722 723 return nt, key, rexpat, tail, ps, pe 724 } 725 726 // Wildcard pattern as finale 727 if ws < len(pattern)-1 { 728 panic("chi: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead") 729 } 730 return ntCatchAll, "*", "", 0, ws, len(pattern) 731} 732 733func patParamKeys(pattern string) []string { 734 pat := pattern 735 paramKeys := []string{} 736 for { 737 ptyp, paramKey, _, _, _, e := patNextSegment(pat) 738 if ptyp == ntStatic { 739 return paramKeys 740 } 741 for i := 0; i < len(paramKeys); i++ { 742 if paramKeys[i] == paramKey { 743 panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey)) 744 } 745 } 746 paramKeys = append(paramKeys, paramKey) 747 pat = pat[e:] 748 } 749} 750 751// longestPrefix finds the length of the shared prefix 752// of two strings 753func longestPrefix(k1, k2 string) int { 754 max := len(k1) 755 if l := len(k2); l < max { 756 max = l 757 } 758 var i int 759 for i = 0; i < max; i++ { 760 if k1[i] != k2[i] { 761 break 762 } 763 } 764 return i 765} 766 767func methodTypString(method methodTyp) string { 768 for s, t := range methodMap { 769 if method == t { 770 return s 771 } 772 } 773 return "" 774} 775 776type nodes []*node 777 778// Sort the list of nodes by label 779func (ns nodes) Sort() { sort.Sort(ns); ns.tailSort() } 
780func (ns nodes) Len() int { return len(ns) } 781func (ns nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] } 782func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label } 783 784// tailSort pushes nodes with '/' as the tail to the end of the list for param nodes. 785// The list order determines the traversal order. 786func (ns nodes) tailSort() { 787 for i := len(ns) - 1; i >= 0; i-- { 788 if ns[i].typ > ntStatic && ns[i].tail == '/' { 789 ns.Swap(i, len(ns)-1) 790 return 791 } 792 } 793} 794 795func (ns nodes) findEdge(label byte) *node { 796 num := len(ns) 797 idx := 0 798 i, j := 0, num-1 799 for i <= j { 800 idx = i + (j-i)/2 801 if label > ns[idx].label { 802 i = idx + 1 803 } else if label < ns[idx].label { 804 j = idx - 1 805 } else { 806 i = num // breaks cond 807 } 808 } 809 if ns[idx].label != label { 810 return nil 811 } 812 return ns[idx] 813} 814 815// Route describes the details of a routing handler. 816// Handlers map key is an HTTP method 817type Route struct { 818 SubRoutes Routes 819 Handlers map[string]http.Handler 820 Pattern string 821} 822 823// WalkFunc is the type of the function called for each method and route visited by Walk. 824type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error 825 826// Walk walks any router tree that implements Routes interface. 827func Walk(r Routes, walkFn WalkFunc) error { 828 return walk(r, walkFn, "") 829} 830 831func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error { 832 for _, route := range r.Routes() { 833 mws := make([]func(http.Handler) http.Handler, len(parentMw)) 834 copy(mws, parentMw) 835 mws = append(mws, r.Middlewares()...) 
836 837 if route.SubRoutes != nil { 838 if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil { 839 return err 840 } 841 continue 842 } 843 844 for method, handler := range route.Handlers { 845 if method == "*" { 846 // Ignore a "catchAll" method, since we pass down all the specific methods for each route. 847 continue 848 } 849 850 fullRoute := parentRoute + route.Pattern 851 fullRoute = strings.Replace(fullRoute, "/*/", "/", -1) 852 853 if chain, ok := handler.(*ChainHandler); ok { 854 if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil { 855 return err 856 } 857 } else { 858 if err := walkFn(method, fullRoute, handler, mws...); err != nil { 859 return err 860 } 861 } 862 } 863 } 864 865 return nil 866} 867