1// Package gorillamux implements a router. 2// 3// It differs from the legacy router: 4// * it provides somewhat granular errors: "path not found", "method not allowed". 5// * it handles matching routes with extensions (e.g. /books/{id}.json) 6// * it handles path patterns with a different syntax (e.g. /params/{x}/{y}/{z:.*}) 7package gorillamux 8 9import ( 10 "net/http" 11 "net/url" 12 "sort" 13 "strings" 14 15 "github.com/getkin/kin-openapi/openapi3" 16 "github.com/getkin/kin-openapi/routers" 17 "github.com/gorilla/mux" 18) 19 20// Router helps link http.Request.s and an OpenAPIv3 spec 21type Router struct { 22 muxes []*mux.Route 23 routes []*routers.Route 24} 25 26// NewRouter creates a gorilla/mux router. 27// Assumes spec is .Validate()d 28// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request)) 29func NewRouter(doc *openapi3.T) (routers.Router, error) { 30 type srv struct { 31 schemes []string 32 host, base string 33 server *openapi3.Server 34 } 35 servers := make([]srv, 0, len(doc.Servers)) 36 for _, server := range doc.Servers { 37 serverURL := server.URL 38 var schemes []string 39 var u *url.URL 40 var err error 41 if strings.Contains(serverURL, "://") { 42 scheme0 := strings.Split(serverURL, "://")[0] 43 schemes = permutePart(scheme0, server) 44 u, err = url.Parse(bEncode(strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1))) 45 } else { 46 u, err = url.Parse(bEncode(serverURL)) 47 } 48 if err != nil { 49 return nil, err 50 } 51 path := bDecode(u.EscapedPath()) 52 if len(path) > 0 && path[len(path)-1] == '/' { 53 path = path[:len(path)-1] 54 } 55 servers = append(servers, srv{ 56 host: bDecode(u.Host), //u.Hostname()? 57 base: path, 58 schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624 59 server: server, 60 }) 61 } 62 if len(servers) == 0 { 63 servers = append(servers, srv{}) 64 } 65 muxRouter := mux.NewRouter().UseEncodedPath() 66 r := &Router{} 67 for _, path := range orderedPaths(doc.Paths) { 68 pathItem := doc.Paths[path] 69 70 operations := pathItem.Operations() 71 methods := make([]string, 0, len(operations)) 72 for method := range operations { 73 methods = append(methods, method) 74 } 75 sort.Strings(methods) 76 77 for _, s := range servers { 78 muxRoute := muxRouter.Path(s.base + path).Methods(methods...) 79 if schemes := s.schemes; len(schemes) != 0 { 80 muxRoute.Schemes(schemes...) 81 } 82 if host := s.host; host != "" { 83 muxRoute.Host(host) 84 } 85 if err := muxRoute.GetError(); err != nil { 86 return nil, err 87 } 88 r.muxes = append(r.muxes, muxRoute) 89 r.routes = append(r.routes, &routers.Route{ 90 Spec: doc, 91 Server: s.server, 92 Path: path, 93 PathItem: pathItem, 94 Method: "", 95 Operation: nil, 96 }) 97 } 98 } 99 return r, nil 100} 101 102// FindRoute extracts the route and parameters of an http.Request 103func (r *Router) FindRoute(req *http.Request) (*routers.Route, map[string]string, error) { 104 for i, muxRoute := range r.muxes { 105 var match mux.RouteMatch 106 if muxRoute.Match(req, &match) { 107 if err := match.MatchErr; err != nil { 108 // What then? 109 } 110 route := r.routes[i] 111 route.Method = req.Method 112 route.Operation = route.Spec.Paths[route.Path].GetOperation(route.Method) 113 return route, match.Vars, nil 114 } 115 switch match.MatchErr { 116 case nil: 117 case mux.ErrMethodMismatch: 118 return nil, nil, routers.ErrMethodNotAllowed 119 default: // What then? 120 } 121 } 122 return nil, nil, routers.ErrPathNotFound 123} 124 125func orderedPaths(paths map[string]*openapi3.PathItem) []string { 126 // https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#pathsObject 127 // When matching URLs, concrete (non-templated) paths would be matched 128 // before their templated counterparts. 129 // NOTE: sorting by number of variables ASC then by lexicographical 130 // order seems to be a good heuristic. 131 vars := make(map[int][]string) 132 max := 0 133 for path := range paths { 134 count := strings.Count(path, "}") 135 vars[count] = append(vars[count], path) 136 if count > max { 137 max = count 138 } 139 } 140 ordered := make([]string, 0, len(paths)) 141 for c := 0; c <= max; c++ { 142 if ps, ok := vars[c]; ok { 143 sort.Strings(ps) 144 ordered = append(ordered, ps...) 145 } 146 } 147 return ordered 148} 149 150// Magic strings that temporarily replace "{}" so net/url.Parse() works 151var blURL, brURL = strings.Repeat("-", 50), strings.Repeat("_", 50) 152 153func bEncode(s string) string { 154 s = strings.Replace(s, "{", blURL, -1) 155 s = strings.Replace(s, "}", brURL, -1) 156 return s 157} 158func bDecode(s string) string { 159 s = strings.Replace(s, blURL, "{", -1) 160 s = strings.Replace(s, brURL, "}", -1) 161 return s 162} 163 164func permutePart(part0 string, srv *openapi3.Server) []string { 165 type mapAndSlice struct { 166 m map[string]struct{} 167 s []string 168 } 169 var2val := make(map[string]mapAndSlice) 170 max := 0 171 for name0, v := range srv.Variables { 172 name := "{" + name0 + "}" 173 if !strings.Contains(part0, name) { 174 continue 175 } 176 m := map[string]struct{}{v.Default: {}} 177 for _, value := range v.Enum { 178 m[value] = struct{}{} 179 } 180 if l := len(m); l > max { 181 max = l 182 } 183 s := make([]string, 0, len(m)) 184 for value := range m { 185 s = append(s, value) 186 } 187 var2val[name] = mapAndSlice{m: m, s: s} 188 } 189 if len(var2val) == 0 { 190 return []string{part0} 191 } 192 193 partsMap := make(map[string]struct{}, max*len(var2val)) 194 for i := 0; i < max; i++ { 195 part := part0 196 for name, mas := range var2val { 197 part = strings.Replace(part, name, mas.s[i%len(mas.s)], -1) 198 } 199 partsMap[part] = struct{}{} 200 } 201 parts := make([]string, 0, len(partsMap)) 202 for part := range partsMap { 203 parts = append(parts, part) 204 } 205 sort.Strings(parts) 206 return parts 207} 208