1// Copyright 2015 go-swagger maintainers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package middleware
16
17import (
18	stdContext "context"
19	"fmt"
20	"net/http"
21	"strings"
22	"sync"
23
24	"github.com/go-openapi/runtime/security"
25
26	"github.com/go-openapi/analysis"
27	"github.com/go-openapi/errors"
28	"github.com/go-openapi/loads"
29	"github.com/go-openapi/runtime"
30	"github.com/go-openapi/runtime/logger"
31	"github.com/go-openapi/runtime/middleware/untyped"
32	"github.com/go-openapi/spec"
33	"github.com/go-openapi/strfmt"
34)
35
36// Debug when true turns on verbose logging
37var Debug = logger.DebugEnabled()
38var Logger logger.Logger = logger.StandardLogger{}
39
40func debugLog(format string, args ...interface{}) {
41	if Debug {
42		Logger.Printf(format, args...)
43	}
44}
45
// Builder decorates an http.Handler, returning the handler to use in its place.
// It is the hook through which middlewares are layered onto the API.
type Builder func(http.Handler) http.Handler

// PassthroughBuilder is the identity Builder: it returns the given handler unchanged.
func PassthroughBuilder(handler http.Handler) http.Handler {
	return handler
}
51
52// RequestBinder is an interface for types to implement
53// when they want to be able to bind from a request
54type RequestBinder interface {
55	BindRequest(*http.Request, *MatchedRoute) error
56}
57
58// Responder is an interface for types to implement
59// when they want to be considered for writing HTTP responses
60type Responder interface {
61	WriteResponse(http.ResponseWriter, runtime.Producer)
62}
63
64// ResponderFunc wraps a func as a Responder interface
65type ResponderFunc func(http.ResponseWriter, runtime.Producer)
66
67// WriteResponse writes to the response
68func (fn ResponderFunc) WriteResponse(rw http.ResponseWriter, pr runtime.Producer) {
69	fn(rw, pr)
70}
71
72// Context is a type safe wrapper around an untyped request context
73// used throughout to store request context with the standard context attached
74// to the http.Request
75type Context struct {
76	spec     *loads.Document
77	analyzer *analysis.Spec
78	api      RoutableAPI
79	router   Router
80}
81
82type routableUntypedAPI struct {
83	api             *untyped.API
84	hlock           *sync.Mutex
85	handlers        map[string]map[string]http.Handler
86	defaultConsumes string
87	defaultProduces string
88}
89
90func newRoutableUntypedAPI(spec *loads.Document, api *untyped.API, context *Context) *routableUntypedAPI {
91	var handlers map[string]map[string]http.Handler
92	if spec == nil || api == nil {
93		return nil
94	}
95	analyzer := analysis.New(spec.Spec())
96	for method, hls := range analyzer.Operations() {
97		um := strings.ToUpper(method)
98		for path, op := range hls {
99			schemes := analyzer.SecurityRequirementsFor(op)
100
101			if oh, ok := api.OperationHandlerFor(method, path); ok {
102				if handlers == nil {
103					handlers = make(map[string]map[string]http.Handler)
104				}
105				if b, ok := handlers[um]; !ok || b == nil {
106					handlers[um] = make(map[string]http.Handler)
107				}
108
109				var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
110					// lookup route info in the context
111					route, rCtx, _ := context.RouteInfo(r)
112					if rCtx != nil {
113						r = rCtx
114					}
115
116					// bind and validate the request using reflection
117					var bound interface{}
118					var validation error
119					bound, r, validation = context.BindAndValidate(r, route)
120					if validation != nil {
121						context.Respond(w, r, route.Produces, route, validation)
122						return
123					}
124
125					// actually handle the request
126					result, err := oh.Handle(bound)
127					if err != nil {
128						// respond with failure
129						context.Respond(w, r, route.Produces, route, err)
130						return
131					}
132
133					// respond with success
134					context.Respond(w, r, route.Produces, route, result)
135				})
136
137				if len(schemes) > 0 {
138					handler = newSecureAPI(context, handler)
139				}
140				handlers[um][path] = handler
141			}
142		}
143	}
144
145	return &routableUntypedAPI{
146		api:             api,
147		hlock:           new(sync.Mutex),
148		handlers:        handlers,
149		defaultProduces: api.DefaultProduces,
150		defaultConsumes: api.DefaultConsumes,
151	}
152}
153
154func (r *routableUntypedAPI) HandlerFor(method, path string) (http.Handler, bool) {
155	r.hlock.Lock()
156	paths, ok := r.handlers[strings.ToUpper(method)]
157	if !ok {
158		r.hlock.Unlock()
159		return nil, false
160	}
161	handler, ok := paths[path]
162	r.hlock.Unlock()
163	return handler, ok
164}
165func (r *routableUntypedAPI) ServeErrorFor(operationID string) func(http.ResponseWriter, *http.Request, error) {
166	return r.api.ServeError
167}
168func (r *routableUntypedAPI) ConsumersFor(mediaTypes []string) map[string]runtime.Consumer {
169	return r.api.ConsumersFor(mediaTypes)
170}
171func (r *routableUntypedAPI) ProducersFor(mediaTypes []string) map[string]runtime.Producer {
172	return r.api.ProducersFor(mediaTypes)
173}
174func (r *routableUntypedAPI) AuthenticatorsFor(schemes map[string]spec.SecurityScheme) map[string]runtime.Authenticator {
175	return r.api.AuthenticatorsFor(schemes)
176}
177func (r *routableUntypedAPI) Authorizer() runtime.Authorizer {
178	return r.api.Authorizer()
179}
180func (r *routableUntypedAPI) Formats() strfmt.Registry {
181	return r.api.Formats()
182}
183
184func (r *routableUntypedAPI) DefaultProduces() string {
185	return r.defaultProduces
186}
187
188func (r *routableUntypedAPI) DefaultConsumes() string {
189	return r.defaultConsumes
190}
191
192// NewRoutableContext creates a new context for a routable API
193func NewRoutableContext(spec *loads.Document, routableAPI RoutableAPI, routes Router) *Context {
194	var an *analysis.Spec
195	if spec != nil {
196		an = analysis.New(spec.Spec())
197	}
198	ctx := &Context{spec: spec, api: routableAPI, analyzer: an, router: routes}
199	return ctx
200}
201
202// NewContext creates a new context wrapper
203func NewContext(spec *loads.Document, api *untyped.API, routes Router) *Context {
204	var an *analysis.Spec
205	if spec != nil {
206		an = analysis.New(spec.Spec())
207	}
208	ctx := &Context{spec: spec, analyzer: an}
209	ctx.api = newRoutableUntypedAPI(spec, api, ctx)
210	ctx.router = routes
211	return ctx
212}
213
214// Serve serves the specified spec with the specified api registrations as a http.Handler
215func Serve(spec *loads.Document, api *untyped.API) http.Handler {
216	return ServeWithBuilder(spec, api, PassthroughBuilder)
217}
218
219// ServeWithBuilder serves the specified spec with the specified api registrations as a http.Handler that is decorated
220// by the Builder
221func ServeWithBuilder(spec *loads.Document, api *untyped.API, builder Builder) http.Handler {
222	context := NewContext(spec, api, nil)
223	return context.APIHandler(builder)
224}
225
// contextKey is the unexported key type for values this package stores in a
// request's context.Context. Because the type is unexported, its keys cannot
// collide with keys defined by other packages.
type contextKey int8

// Keys under which request-scoped values are cached. Numbering starts at 1,
// so the zero contextKey is never used as a key.
const (
	ctxContentType contextKey = iota + 1
	ctxResponseFormat
	ctxMatchedRoute
	ctxBoundParams
	ctxSecurityPrincipal
	ctxSecurityScopes
)
237
238// MatchedRouteFrom request context value.
239func MatchedRouteFrom(req *http.Request) *MatchedRoute {
240	mr := req.Context().Value(ctxMatchedRoute)
241	if mr == nil {
242		return nil
243	}
244	if res, ok := mr.(*MatchedRoute); ok {
245		return res
246	}
247	return nil
248}
249
250// SecurityPrincipalFrom request context value.
251func SecurityPrincipalFrom(req *http.Request) interface{} {
252	return req.Context().Value(ctxSecurityPrincipal)
253}
254
255// SecurityScopesFrom request context value.
256func SecurityScopesFrom(req *http.Request) []string {
257	rs := req.Context().Value(ctxSecurityScopes)
258	if res, ok := rs.([]string); ok {
259		return res
260	}
261	return nil
262}
263
// contentTypeValue is the parsed Content-Type (media type and charset) that
// ContentType caches in the request context.
type contentTypeValue struct {
	MediaType, Charset string
}
268
269// BasePath returns the base path for this API
270func (c *Context) BasePath() string {
271	return c.spec.BasePath()
272}
273
274// RequiredProduces returns the accepted content types for responses
275func (c *Context) RequiredProduces() []string {
276	return c.analyzer.RequiredProduces()
277}
278
279// BindValidRequest binds a params object to a request but only when the request is valid
280// if the request is not valid an error will be returned
281func (c *Context) BindValidRequest(request *http.Request, route *MatchedRoute, binder RequestBinder) error {
282	var res []error
283
284	requestContentType := "*/*"
285	// check and validate content type, select consumer
286	if runtime.HasBody(request) {
287		ct, _, err := runtime.ContentType(request.Header)
288		if err != nil {
289			res = append(res, err)
290		} else {
291			if err := validateContentType(route.Consumes, ct); err != nil {
292				res = append(res, err)
293			}
294			if len(res) == 0 {
295				cons, ok := route.Consumers[ct]
296				if !ok {
297					res = append(res, errors.New(500, "no consumer registered for %s", ct))
298				} else {
299					route.Consumer = cons
300					requestContentType = ct
301				}
302			}
303		}
304	}
305
306	// check and validate the response format
307	if len(res) == 0 && runtime.HasBody(request) {
308		if str := NegotiateContentType(request, route.Produces, requestContentType); str == "" {
309			res = append(res, errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces))
310		}
311	}
312
313	// now bind the request with the provided binder
314	// it's assumed the binder will also validate the request and return an error if the
315	// request is invalid
316	if binder != nil && len(res) == 0 {
317		if err := binder.BindRequest(request, route); err != nil {
318			return err
319		}
320	}
321
322	if len(res) > 0 {
323		return errors.CompositeValidationError(res...)
324	}
325	return nil
326}
327
328// ContentType gets the parsed value of a content type
329// Returns the media type, its charset and a shallow copy of the request
330// when its context doesn't contain the content type value, otherwise it returns
331// the same request
332// Returns the error that runtime.ContentType may retunrs.
333func (c *Context) ContentType(request *http.Request) (string, string, *http.Request, error) {
334	var rCtx = request.Context()
335
336	if v, ok := rCtx.Value(ctxContentType).(*contentTypeValue); ok {
337		return v.MediaType, v.Charset, request, nil
338	}
339
340	mt, cs, err := runtime.ContentType(request.Header)
341	if err != nil {
342		return "", "", nil, err
343	}
344	rCtx = stdContext.WithValue(rCtx, ctxContentType, &contentTypeValue{mt, cs})
345	return mt, cs, request.WithContext(rCtx), nil
346}
347
348// LookupRoute looks a route up and returns true when it is found
349func (c *Context) LookupRoute(request *http.Request) (*MatchedRoute, bool) {
350	if route, ok := c.router.Lookup(request.Method, request.URL.EscapedPath()); ok {
351		return route, ok
352	}
353	return nil, false
354}
355
356// RouteInfo tries to match a route for this request
357// Returns the matched route, a shallow copy of the request if its context
358// contains the matched router, otherwise the same request, and a bool to
359// indicate if it the request matches one of the routes, if it doesn't
360// then it returns false and nil for the other two return values
361func (c *Context) RouteInfo(request *http.Request) (*MatchedRoute, *http.Request, bool) {
362	var rCtx = request.Context()
363
364	if v, ok := rCtx.Value(ctxMatchedRoute).(*MatchedRoute); ok {
365		return v, request, ok
366	}
367
368	if route, ok := c.LookupRoute(request); ok {
369		rCtx = stdContext.WithValue(rCtx, ctxMatchedRoute, route)
370		return route, request.WithContext(rCtx), ok
371	}
372
373	return nil, nil, false
374}
375
376// ResponseFormat negotiates the response content type
377// Returns the response format and a shallow copy of the request if its context
378// doesn't contain the response format, otherwise the same request
379func (c *Context) ResponseFormat(r *http.Request, offers []string) (string, *http.Request) {
380	var rCtx = r.Context()
381
382	if v, ok := rCtx.Value(ctxResponseFormat).(string); ok {
383		debugLog("[%s %s] found response format %q in context", r.Method, r.URL.Path, v)
384		return v, r
385	}
386
387	format := NegotiateContentType(r, offers, "")
388	if format != "" {
389		debugLog("[%s %s] set response format %q in context", r.Method, r.URL.Path, format)
390		r = r.WithContext(stdContext.WithValue(rCtx, ctxResponseFormat, format))
391	}
392	debugLog("[%s %s] negotiated response format %q", r.Method, r.URL.Path, format)
393	return format, r
394}
395
396// AllowedMethods gets the allowed methods for the path of this request
397func (c *Context) AllowedMethods(request *http.Request) []string {
398	return c.router.OtherMethods(request.Method, request.URL.EscapedPath())
399}
400
401// ResetAuth removes the current principal from the request context
402func (c *Context) ResetAuth(request *http.Request) *http.Request {
403	rctx := request.Context()
404	rctx = stdContext.WithValue(rctx, ctxSecurityPrincipal, nil)
405	rctx = stdContext.WithValue(rctx, ctxSecurityScopes, nil)
406	return request.WithContext(rctx)
407}
408
409// Authorize authorizes the request
410// Returns the principal object and a shallow copy of the request when its
411// context doesn't contain the principal, otherwise the same request or an error
412// (the last) if one of the authenticators returns one or an Unauthenticated error
413func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (interface{}, *http.Request, error) {
414	if route == nil || !route.HasAuth() {
415		return nil, nil, nil
416	}
417
418	var rCtx = request.Context()
419	if v := rCtx.Value(ctxSecurityPrincipal); v != nil {
420		return v, request, nil
421	}
422
423	applies, usr, err := route.Authenticators.Authenticate(request, route)
424	if !applies || err != nil || !route.Authenticators.AllowsAnonymous() && usr == nil {
425		if err != nil {
426			return nil, nil, err
427		}
428		return nil, nil, errors.Unauthenticated("invalid credentials")
429	}
430	if route.Authorizer != nil {
431		if err := route.Authorizer.Authorize(request, usr); err != nil {
432			return nil, nil, errors.New(http.StatusForbidden, err.Error())
433		}
434	}
435
436	rCtx = stdContext.WithValue(rCtx, ctxSecurityPrincipal, usr)
437	rCtx = stdContext.WithValue(rCtx, ctxSecurityScopes, route.Authenticator.AllScopes())
438	return usr, request.WithContext(rCtx), nil
439}
440
441// BindAndValidate binds and validates the request
442// Returns the validation map and a shallow copy of the request when its context
443// doesn't contain the validation, otherwise it returns the same request or an
444// CompositeValidationError error
445func (c *Context) BindAndValidate(request *http.Request, matched *MatchedRoute) (interface{}, *http.Request, error) {
446	var rCtx = request.Context()
447
448	if v, ok := rCtx.Value(ctxBoundParams).(*validation); ok {
449		debugLog("got cached validation (valid: %t)", len(v.result) == 0)
450		if len(v.result) > 0 {
451			return v.bound, request, errors.CompositeValidationError(v.result...)
452		}
453		return v.bound, request, nil
454	}
455	result := validateRequest(c, request, matched)
456	rCtx = stdContext.WithValue(rCtx, ctxBoundParams, result)
457	request = request.WithContext(rCtx)
458	if len(result.result) > 0 {
459		return result.bound, request, errors.CompositeValidationError(result.result...)
460	}
461	debugLog("no validation errors found")
462	return result.bound, request, nil
463}
464
465// NotFound the default not found responder for when no route has been matched yet
466func (c *Context) NotFound(rw http.ResponseWriter, r *http.Request) {
467	c.Respond(rw, r, []string{c.api.DefaultProduces()}, nil, errors.NotFound("not found"))
468}
469
470// Respond renders the response after doing some content negotiation
471func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, data interface{}) {
472	debugLog("responding to %s %s with produces: %v", r.Method, r.URL.Path, produces)
473	offers := []string{}
474	for _, mt := range produces {
475		if mt != c.api.DefaultProduces() {
476			offers = append(offers, mt)
477		}
478	}
479	// the default producer is last so more specific producers take precedence
480	offers = append(offers, c.api.DefaultProduces())
481	debugLog("offers: %v", offers)
482
483	var format string
484	format, r = c.ResponseFormat(r, offers)
485	rw.Header().Set(runtime.HeaderContentType, format)
486
487	if resp, ok := data.(Responder); ok {
488		producers := route.Producers
489		prod, ok := producers[format]
490		if !ok {
491			prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()}))
492			pr, ok := prods[c.api.DefaultProduces()]
493			if !ok {
494				panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
495			}
496			prod = pr
497		}
498		resp.WriteResponse(rw, prod)
499		return
500	}
501
502	if err, ok := data.(error); ok {
503		if format == "" {
504			rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime)
505		}
506
507		if realm := security.FailedBasicAuth(r); realm != "" {
508			rw.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", realm))
509		}
510
511		if route == nil || route.Operation == nil {
512			c.api.ServeErrorFor("")(rw, r, err)
513			return
514		}
515		c.api.ServeErrorFor(route.Operation.ID)(rw, r, err)
516		return
517	}
518
519	if route == nil || route.Operation == nil {
520		rw.WriteHeader(200)
521		if r.Method == "HEAD" {
522			return
523		}
524		producers := c.api.ProducersFor(normalizeOffers(offers))
525		prod, ok := producers[format]
526		if !ok {
527			panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
528		}
529		if err := prod.Produce(rw, data); err != nil {
530			panic(err) // let the recovery middleware deal with this
531		}
532		return
533	}
534
535	if _, code, ok := route.Operation.SuccessResponse(); ok {
536		rw.WriteHeader(code)
537		if code == 204 || r.Method == "HEAD" {
538			return
539		}
540
541		producers := route.Producers
542		prod, ok := producers[format]
543		if !ok {
544			if !ok {
545				prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()}))
546				pr, ok := prods[c.api.DefaultProduces()]
547				if !ok {
548					panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format))
549				}
550				prod = pr
551			}
552		}
553		if err := prod.Produce(rw, data); err != nil {
554			panic(err) // let the recovery middleware deal with this
555		}
556		return
557	}
558
559	c.api.ServeErrorFor(route.Operation.ID)(rw, r, errors.New(http.StatusInternalServerError, "can't produce response"))
560}
561
562// APIHandler returns a handler to serve the API, this includes a swagger spec, router and the contract defined in the swagger spec
563func (c *Context) APIHandler(builder Builder) http.Handler {
564	b := builder
565	if b == nil {
566		b = PassthroughBuilder
567	}
568
569	var title string
570	sp := c.spec.Spec()
571	if sp != nil && sp.Info != nil && sp.Info.Title != "" {
572		title = sp.Info.Title
573	}
574
575	redocOpts := RedocOpts{
576		BasePath: c.BasePath(),
577		Title:    title,
578	}
579
580	return Spec("", c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b)))
581}
582
583// RoutesHandler returns a handler to serve the API, just the routes and the contract defined in the swagger spec
584func (c *Context) RoutesHandler(builder Builder) http.Handler {
585	b := builder
586	if b == nil {
587		b = PassthroughBuilder
588	}
589	return NewRouter(c, b(NewOperationExecutor(c)))
590}
591