1// Package gorillamux implements a router.
2//
3// It differs from the legacy router:
4// * it provides somewhat granular errors: "path not found", "method not allowed".
5// * it handles matching routes with extensions (e.g. /books/{id}.json)
6// * it handles path patterns with a different syntax (e.g. /params/{x}/{y}/{z:.*})
7package gorillamux
8
9import (
10	"net/http"
11	"net/url"
12	"sort"
13	"strings"
14
15	"github.com/getkin/kin-openapi/openapi3"
16	"github.com/getkin/kin-openapi/routers"
17	"github.com/gorilla/mux"
18)
19
20// Router helps link http.Request.s and an OpenAPIv3 spec
21type Router struct {
22	muxes  []*mux.Route
23	routes []*routers.Route
24}
25
26// NewRouter creates a gorilla/mux router.
27// Assumes spec is .Validate()d
28// TODO: Handle/HandlerFunc + ServeHTTP (When there is a match, the route variables can be retrieved calling mux.Vars(request))
29func NewRouter(doc *openapi3.T) (routers.Router, error) {
30	type srv struct {
31		schemes    []string
32		host, base string
33		server     *openapi3.Server
34	}
35	servers := make([]srv, 0, len(doc.Servers))
36	for _, server := range doc.Servers {
37		serverURL := server.URL
38		var schemes []string
39		var u *url.URL
40		var err error
41		if strings.Contains(serverURL, "://") {
42			scheme0 := strings.Split(serverURL, "://")[0]
43			schemes = permutePart(scheme0, server)
44			u, err = url.Parse(bEncode(strings.Replace(serverURL, scheme0+"://", schemes[0]+"://", 1)))
45		} else {
46			u, err = url.Parse(bEncode(serverURL))
47		}
48		if err != nil {
49			return nil, err
50		}
51		path := bDecode(u.EscapedPath())
52		if len(path) > 0 && path[len(path)-1] == '/' {
53			path = path[:len(path)-1]
54		}
55		servers = append(servers, srv{
56			host:    bDecode(u.Host), //u.Hostname()?
57			base:    path,
58			schemes: schemes, // scheme: []string{scheme0}, TODO: https://github.com/gorilla/mux/issues/624
59			server:  server,
60		})
61	}
62	if len(servers) == 0 {
63		servers = append(servers, srv{})
64	}
65	muxRouter := mux.NewRouter().UseEncodedPath()
66	r := &Router{}
67	for _, path := range orderedPaths(doc.Paths) {
68		pathItem := doc.Paths[path]
69
70		operations := pathItem.Operations()
71		methods := make([]string, 0, len(operations))
72		for method := range operations {
73			methods = append(methods, method)
74		}
75		sort.Strings(methods)
76
77		for _, s := range servers {
78			muxRoute := muxRouter.Path(s.base + path).Methods(methods...)
79			if schemes := s.schemes; len(schemes) != 0 {
80				muxRoute.Schemes(schemes...)
81			}
82			if host := s.host; host != "" {
83				muxRoute.Host(host)
84			}
85			if err := muxRoute.GetError(); err != nil {
86				return nil, err
87			}
88			r.muxes = append(r.muxes, muxRoute)
89			r.routes = append(r.routes, &routers.Route{
90				Spec:      doc,
91				Server:    s.server,
92				Path:      path,
93				PathItem:  pathItem,
94				Method:    "",
95				Operation: nil,
96			})
97		}
98	}
99	return r, nil
100}
101
102// FindRoute extracts the route and parameters of an http.Request
103func (r *Router) FindRoute(req *http.Request) (*routers.Route, map[string]string, error) {
104	for i, muxRoute := range r.muxes {
105		var match mux.RouteMatch
106		if muxRoute.Match(req, &match) {
107			if err := match.MatchErr; err != nil {
108				// What then?
109			}
110			route := r.routes[i]
111			route.Method = req.Method
112			route.Operation = route.Spec.Paths[route.Path].GetOperation(route.Method)
113			return route, match.Vars, nil
114		}
115		switch match.MatchErr {
116		case nil:
117		case mux.ErrMethodMismatch:
118			return nil, nil, routers.ErrMethodNotAllowed
119		default: // What then?
120		}
121	}
122	return nil, nil, routers.ErrPathNotFound
123}
124
125func orderedPaths(paths map[string]*openapi3.PathItem) []string {
126	// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#pathsObject
127	// When matching URLs, concrete (non-templated) paths would be matched
128	// before their templated counterparts.
129	// NOTE: sorting by number of variables ASC then by lexicographical
130	// order seems to be a good heuristic.
131	vars := make(map[int][]string)
132	max := 0
133	for path := range paths {
134		count := strings.Count(path, "}")
135		vars[count] = append(vars[count], path)
136		if count > max {
137			max = count
138		}
139	}
140	ordered := make([]string, 0, len(paths))
141	for c := 0; c <= max; c++ {
142		if ps, ok := vars[c]; ok {
143			sort.Strings(ps)
144			ordered = append(ordered, ps...)
145		}
146	}
147	return ordered
148}
149
// blURL and brURL are placeholders substituted for '{' and '}' so that
// templated server URLs survive net/url.Parse(); bDecode puts the braces back.
var blURL, brURL = strings.Repeat("-", 50), strings.Repeat("_", 50)

// braceEncoder and braceDecoder apply both substitutions in a single pass.
var (
	braceEncoder = strings.NewReplacer("{", blURL, "}", brURL)
	braceDecoder = strings.NewReplacer(blURL, "{", brURL, "}")
)

// bEncode replaces every '{' with blURL and every '}' with brURL.
func bEncode(s string) string {
	return braceEncoder.Replace(s)
}

// bDecode is the inverse of bEncode: it turns every blURL back into '{' and
// every brURL back into '}'.
func bDecode(s string) string {
	return braceDecoder.Replace(s)
}
163
164func permutePart(part0 string, srv *openapi3.Server) []string {
165	type mapAndSlice struct {
166		m map[string]struct{}
167		s []string
168	}
169	var2val := make(map[string]mapAndSlice)
170	max := 0
171	for name0, v := range srv.Variables {
172		name := "{" + name0 + "}"
173		if !strings.Contains(part0, name) {
174			continue
175		}
176		m := map[string]struct{}{v.Default: {}}
177		for _, value := range v.Enum {
178			m[value] = struct{}{}
179		}
180		if l := len(m); l > max {
181			max = l
182		}
183		s := make([]string, 0, len(m))
184		for value := range m {
185			s = append(s, value)
186		}
187		var2val[name] = mapAndSlice{m: m, s: s}
188	}
189	if len(var2val) == 0 {
190		return []string{part0}
191	}
192
193	partsMap := make(map[string]struct{}, max*len(var2val))
194	for i := 0; i < max; i++ {
195		part := part0
196		for name, mas := range var2val {
197			part = strings.Replace(part, name, mas.s[i%len(mas.s)], -1)
198		}
199		partsMap[part] = struct{}{}
200	}
201	parts := make([]string, 0, len(partsMap))
202	for part := range partsMap {
203		parts = append(parts, part)
204	}
205	sort.Strings(parts)
206	return parts
207}
208