1// Copyright 2015 go-swagger maintainers 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package middleware 16 17import ( 18 stdContext "context" 19 "fmt" 20 "net/http" 21 "strings" 22 "sync" 23 24 "github.com/go-openapi/runtime/security" 25 26 "github.com/go-openapi/analysis" 27 "github.com/go-openapi/errors" 28 "github.com/go-openapi/loads" 29 "github.com/go-openapi/runtime" 30 "github.com/go-openapi/runtime/logger" 31 "github.com/go-openapi/runtime/middleware/untyped" 32 "github.com/go-openapi/spec" 33 "github.com/go-openapi/strfmt" 34) 35 36// Debug when true turns on verbose logging 37var Debug = logger.DebugEnabled() 38var Logger logger.Logger = logger.StandardLogger{} 39 40func debugLog(format string, args ...interface{}) { 41 if Debug { 42 Logger.Printf(format, args...) 43 } 44} 45 46// A Builder can create middlewares 47type Builder func(http.Handler) http.Handler 48 49// PassthroughBuilder returns the handler, aka the builder identity function 50func PassthroughBuilder(handler http.Handler) http.Handler { return handler } 51 52// RequestBinder is an interface for types to implement 53// when they want to be able to bind from a request 54type RequestBinder interface { 55 BindRequest(*http.Request, *MatchedRoute) error 56} 57 58// Responder is an interface for types to implement 59// when they want to be considered for writing HTTP responses 60type Responder interface { 61 WriteResponse(http.ResponseWriter, runtime.Producer) 62} 63 64// ResponderFunc wraps a func as a Responder interface 65type ResponderFunc func(http.ResponseWriter, runtime.Producer) 66 67// WriteResponse writes to the response 68func (fn ResponderFunc) WriteResponse(rw http.ResponseWriter, pr runtime.Producer) { 69 fn(rw, pr) 70} 71 72// Context is a type safe wrapper around an untyped request context 73// used throughout to store request context with the standard context attached 74// to the http.Request 75type Context struct { 76 spec *loads.Document 77 analyzer *analysis.Spec 78 api RoutableAPI 79 router Router 80} 81 82type routableUntypedAPI struct { 83 api *untyped.API 84 hlock *sync.Mutex 85 handlers map[string]map[string]http.Handler 86 defaultConsumes string 87 defaultProduces string 88} 89 90func newRoutableUntypedAPI(spec *loads.Document, api *untyped.API, context *Context) *routableUntypedAPI { 91 var handlers map[string]map[string]http.Handler 92 if spec == nil || api == nil { 93 return nil 94 } 95 analyzer := analysis.New(spec.Spec()) 96 for method, hls := range analyzer.Operations() { 97 um := strings.ToUpper(method) 98 for path, op := range hls { 99 schemes := analyzer.SecurityRequirementsFor(op) 100 101 if oh, ok := api.OperationHandlerFor(method, path); ok { 102 if handlers == nil { 103 handlers = make(map[string]map[string]http.Handler) 104 } 105 if b, ok := handlers[um]; !ok || b == nil { 106 handlers[um] = make(map[string]http.Handler) 107 } 108 109 var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 110 // lookup route info in the context 111 route, rCtx, _ := context.RouteInfo(r) 112 if rCtx != nil { 113 r = rCtx 114 } 115 116 // bind and validate the request using reflection 117 var bound interface{} 118 var validation error 119 bound, r, validation = context.BindAndValidate(r, route) 120 if validation != nil { 121 context.Respond(w, r, route.Produces, route, validation) 122 return 123 } 124 125 // actually handle the request 126 result, err := oh.Handle(bound) 127 if err != nil { 128 // respond with failure 129 context.Respond(w, r, route.Produces, route, err) 130 return 131 } 132 133 // respond with success 134 context.Respond(w, r, route.Produces, route, result) 135 }) 136 137 if len(schemes) > 0 { 138 handler = newSecureAPI(context, handler) 139 } 140 handlers[um][path] = handler 141 } 142 } 143 } 144 145 return &routableUntypedAPI{ 146 api: api, 147 hlock: new(sync.Mutex), 148 handlers: handlers, 149 defaultProduces: api.DefaultProduces, 150 defaultConsumes: api.DefaultConsumes, 151 } 152} 153 154func (r *routableUntypedAPI) HandlerFor(method, path string) (http.Handler, bool) { 155 r.hlock.Lock() 156 paths, ok := r.handlers[strings.ToUpper(method)] 157 if !ok { 158 r.hlock.Unlock() 159 return nil, false 160 } 161 handler, ok := paths[path] 162 r.hlock.Unlock() 163 return handler, ok 164} 165func (r *routableUntypedAPI) ServeErrorFor(operationID string) func(http.ResponseWriter, *http.Request, error) { 166 return r.api.ServeError 167} 168func (r *routableUntypedAPI) ConsumersFor(mediaTypes []string) map[string]runtime.Consumer { 169 return r.api.ConsumersFor(mediaTypes) 170} 171func (r *routableUntypedAPI) ProducersFor(mediaTypes []string) map[string]runtime.Producer { 172 return r.api.ProducersFor(mediaTypes) 173} 174func (r *routableUntypedAPI) AuthenticatorsFor(schemes map[string]spec.SecurityScheme) map[string]runtime.Authenticator { 175 return r.api.AuthenticatorsFor(schemes) 176} 177func (r *routableUntypedAPI) Authorizer() runtime.Authorizer { 178 return r.api.Authorizer() 179} 180func (r *routableUntypedAPI) Formats() strfmt.Registry { 181 return r.api.Formats() 182} 183 184func (r *routableUntypedAPI) DefaultProduces() string { 185 return r.defaultProduces 186} 187 188func (r *routableUntypedAPI) DefaultConsumes() string { 189 return r.defaultConsumes 190} 191 192// NewRoutableContext creates a new context for a routable API 193func NewRoutableContext(spec *loads.Document, routableAPI RoutableAPI, routes Router) *Context { 194 var an *analysis.Spec 195 if spec != nil { 196 an = analysis.New(spec.Spec()) 197 } 198 ctx := &Context{spec: spec, api: routableAPI, analyzer: an, router: routes} 199 return ctx 200} 201 202// NewContext creates a new context wrapper 203func NewContext(spec *loads.Document, api *untyped.API, routes Router) *Context { 204 var an *analysis.Spec 205 if spec != nil { 206 an = analysis.New(spec.Spec()) 207 } 208 ctx := &Context{spec: spec, analyzer: an} 209 ctx.api = newRoutableUntypedAPI(spec, api, ctx) 210 ctx.router = routes 211 return ctx 212} 213 214// Serve serves the specified spec with the specified api registrations as a http.Handler 215func Serve(spec *loads.Document, api *untyped.API) http.Handler { 216 return ServeWithBuilder(spec, api, PassthroughBuilder) 217} 218 219// ServeWithBuilder serves the specified spec with the specified api registrations as a http.Handler that is decorated 220// by the Builder 221func ServeWithBuilder(spec *loads.Document, api *untyped.API, builder Builder) http.Handler { 222 context := NewContext(spec, api, nil) 223 return context.APIHandler(builder) 224} 225 226type contextKey int8 227 228const ( 229 _ contextKey = iota 230 ctxContentType 231 ctxResponseFormat 232 ctxMatchedRoute 233 ctxBoundParams 234 ctxSecurityPrincipal 235 ctxSecurityScopes 236) 237 238// MatchedRouteFrom request context value. 239func MatchedRouteFrom(req *http.Request) *MatchedRoute { 240 mr := req.Context().Value(ctxMatchedRoute) 241 if mr == nil { 242 return nil 243 } 244 if res, ok := mr.(*MatchedRoute); ok { 245 return res 246 } 247 return nil 248} 249 250// SecurityPrincipalFrom request context value. 251func SecurityPrincipalFrom(req *http.Request) interface{} { 252 return req.Context().Value(ctxSecurityPrincipal) 253} 254 255// SecurityScopesFrom request context value. 256func SecurityScopesFrom(req *http.Request) []string { 257 rs := req.Context().Value(ctxSecurityScopes) 258 if res, ok := rs.([]string); ok { 259 return res 260 } 261 return nil 262} 263 264type contentTypeValue struct { 265 MediaType string 266 Charset string 267} 268 269// BasePath returns the base path for this API 270func (c *Context) BasePath() string { 271 return c.spec.BasePath() 272} 273 274// RequiredProduces returns the accepted content types for responses 275func (c *Context) RequiredProduces() []string { 276 return c.analyzer.RequiredProduces() 277} 278 279// BindValidRequest binds a params object to a request but only when the request is valid 280// if the request is not valid an error will be returned 281func (c *Context) BindValidRequest(request *http.Request, route *MatchedRoute, binder RequestBinder) error { 282 var res []error 283 284 requestContentType := "*/*" 285 // check and validate content type, select consumer 286 if runtime.HasBody(request) { 287 ct, _, err := runtime.ContentType(request.Header) 288 if err != nil { 289 res = append(res, err) 290 } else { 291 if err := validateContentType(route.Consumes, ct); err != nil { 292 res = append(res, err) 293 } 294 if len(res) == 0 { 295 cons, ok := route.Consumers[ct] 296 if !ok { 297 res = append(res, errors.New(500, "no consumer registered for %s", ct)) 298 } else { 299 route.Consumer = cons 300 requestContentType = ct 301 } 302 } 303 } 304 } 305 306 // check and validate the response format 307 if len(res) == 0 && runtime.HasBody(request) { 308 if str := NegotiateContentType(request, route.Produces, requestContentType); str == "" { 309 res = append(res, errors.InvalidResponseFormat(request.Header.Get(runtime.HeaderAccept), route.Produces)) 310 } 311 } 312 313 // now bind the request with the provided binder 314 // it's assumed the binder will also validate the request and return an error if the 315 // request is invalid 316 if binder != nil && len(res) == 0 { 317 if err := binder.BindRequest(request, route); err != nil { 318 return err 319 } 320 } 321 322 if len(res) > 0 { 323 return errors.CompositeValidationError(res...) 324 } 325 return nil 326} 327 328// ContentType gets the parsed value of a content type 329// Returns the media type, its charset and a shallow copy of the request 330// when its context doesn't contain the content type value, otherwise it returns 331// the same request 332// Returns the error that runtime.ContentType may retunrs. 333func (c *Context) ContentType(request *http.Request) (string, string, *http.Request, error) { 334 var rCtx = request.Context() 335 336 if v, ok := rCtx.Value(ctxContentType).(*contentTypeValue); ok { 337 return v.MediaType, v.Charset, request, nil 338 } 339 340 mt, cs, err := runtime.ContentType(request.Header) 341 if err != nil { 342 return "", "", nil, err 343 } 344 rCtx = stdContext.WithValue(rCtx, ctxContentType, &contentTypeValue{mt, cs}) 345 return mt, cs, request.WithContext(rCtx), nil 346} 347 348// LookupRoute looks a route up and returns true when it is found 349func (c *Context) LookupRoute(request *http.Request) (*MatchedRoute, bool) { 350 if route, ok := c.router.Lookup(request.Method, request.URL.EscapedPath()); ok { 351 return route, ok 352 } 353 return nil, false 354} 355 356// RouteInfo tries to match a route for this request 357// Returns the matched route, a shallow copy of the request if its context 358// contains the matched router, otherwise the same request, and a bool to 359// indicate if it the request matches one of the routes, if it doesn't 360// then it returns false and nil for the other two return values 361func (c *Context) RouteInfo(request *http.Request) (*MatchedRoute, *http.Request, bool) { 362 var rCtx = request.Context() 363 364 if v, ok := rCtx.Value(ctxMatchedRoute).(*MatchedRoute); ok { 365 return v, request, ok 366 } 367 368 if route, ok := c.LookupRoute(request); ok { 369 rCtx = stdContext.WithValue(rCtx, ctxMatchedRoute, route) 370 return route, request.WithContext(rCtx), ok 371 } 372 373 return nil, nil, false 374} 375 376// ResponseFormat negotiates the response content type 377// Returns the response format and a shallow copy of the request if its context 378// doesn't contain the response format, otherwise the same request 379func (c *Context) ResponseFormat(r *http.Request, offers []string) (string, *http.Request) { 380 var rCtx = r.Context() 381 382 if v, ok := rCtx.Value(ctxResponseFormat).(string); ok { 383 debugLog("[%s %s] found response format %q in context", r.Method, r.URL.Path, v) 384 return v, r 385 } 386 387 format := NegotiateContentType(r, offers, "") 388 if format != "" { 389 debugLog("[%s %s] set response format %q in context", r.Method, r.URL.Path, format) 390 r = r.WithContext(stdContext.WithValue(rCtx, ctxResponseFormat, format)) 391 } 392 debugLog("[%s %s] negotiated response format %q", r.Method, r.URL.Path, format) 393 return format, r 394} 395 396// AllowedMethods gets the allowed methods for the path of this request 397func (c *Context) AllowedMethods(request *http.Request) []string { 398 return c.router.OtherMethods(request.Method, request.URL.EscapedPath()) 399} 400 401// ResetAuth removes the current principal from the request context 402func (c *Context) ResetAuth(request *http.Request) *http.Request { 403 rctx := request.Context() 404 rctx = stdContext.WithValue(rctx, ctxSecurityPrincipal, nil) 405 rctx = stdContext.WithValue(rctx, ctxSecurityScopes, nil) 406 return request.WithContext(rctx) 407} 408 409// Authorize authorizes the request 410// Returns the principal object and a shallow copy of the request when its 411// context doesn't contain the principal, otherwise the same request or an error 412// (the last) if one of the authenticators returns one or an Unauthenticated error 413func (c *Context) Authorize(request *http.Request, route *MatchedRoute) (interface{}, *http.Request, error) { 414 if route == nil || !route.HasAuth() { 415 return nil, nil, nil 416 } 417 418 var rCtx = request.Context() 419 if v := rCtx.Value(ctxSecurityPrincipal); v != nil { 420 return v, request, nil 421 } 422 423 applies, usr, err := route.Authenticators.Authenticate(request, route) 424 if !applies || err != nil || !route.Authenticators.AllowsAnonymous() && usr == nil { 425 if err != nil { 426 return nil, nil, err 427 } 428 return nil, nil, errors.Unauthenticated("invalid credentials") 429 } 430 if route.Authorizer != nil { 431 if err := route.Authorizer.Authorize(request, usr); err != nil { 432 return nil, nil, errors.New(http.StatusForbidden, err.Error()) 433 } 434 } 435 436 rCtx = stdContext.WithValue(rCtx, ctxSecurityPrincipal, usr) 437 rCtx = stdContext.WithValue(rCtx, ctxSecurityScopes, route.Authenticator.AllScopes()) 438 return usr, request.WithContext(rCtx), nil 439} 440 441// BindAndValidate binds and validates the request 442// Returns the validation map and a shallow copy of the request when its context 443// doesn't contain the validation, otherwise it returns the same request or an 444// CompositeValidationError error 445func (c *Context) BindAndValidate(request *http.Request, matched *MatchedRoute) (interface{}, *http.Request, error) { 446 var rCtx = request.Context() 447 448 if v, ok := rCtx.Value(ctxBoundParams).(*validation); ok { 449 debugLog("got cached validation (valid: %t)", len(v.result) == 0) 450 if len(v.result) > 0 { 451 return v.bound, request, errors.CompositeValidationError(v.result...) 452 } 453 return v.bound, request, nil 454 } 455 result := validateRequest(c, request, matched) 456 rCtx = stdContext.WithValue(rCtx, ctxBoundParams, result) 457 request = request.WithContext(rCtx) 458 if len(result.result) > 0 { 459 return result.bound, request, errors.CompositeValidationError(result.result...) 460 } 461 debugLog("no validation errors found") 462 return result.bound, request, nil 463} 464 465// NotFound the default not found responder for when no route has been matched yet 466func (c *Context) NotFound(rw http.ResponseWriter, r *http.Request) { 467 c.Respond(rw, r, []string{c.api.DefaultProduces()}, nil, errors.NotFound("not found")) 468} 469 470// Respond renders the response after doing some content negotiation 471func (c *Context) Respond(rw http.ResponseWriter, r *http.Request, produces []string, route *MatchedRoute, data interface{}) { 472 debugLog("responding to %s %s with produces: %v", r.Method, r.URL.Path, produces) 473 offers := []string{} 474 for _, mt := range produces { 475 if mt != c.api.DefaultProduces() { 476 offers = append(offers, mt) 477 } 478 } 479 // the default producer is last so more specific producers take precedence 480 offers = append(offers, c.api.DefaultProduces()) 481 debugLog("offers: %v", offers) 482 483 var format string 484 format, r = c.ResponseFormat(r, offers) 485 rw.Header().Set(runtime.HeaderContentType, format) 486 487 if resp, ok := data.(Responder); ok { 488 producers := route.Producers 489 prod, ok := producers[format] 490 if !ok { 491 prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()})) 492 pr, ok := prods[c.api.DefaultProduces()] 493 if !ok { 494 panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format)) 495 } 496 prod = pr 497 } 498 resp.WriteResponse(rw, prod) 499 return 500 } 501 502 if err, ok := data.(error); ok { 503 if format == "" { 504 rw.Header().Set(runtime.HeaderContentType, runtime.JSONMime) 505 } 506 507 if realm := security.FailedBasicAuth(r); realm != "" { 508 rw.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", realm)) 509 } 510 511 if route == nil || route.Operation == nil { 512 c.api.ServeErrorFor("")(rw, r, err) 513 return 514 } 515 c.api.ServeErrorFor(route.Operation.ID)(rw, r, err) 516 return 517 } 518 519 if route == nil || route.Operation == nil { 520 rw.WriteHeader(200) 521 if r.Method == "HEAD" { 522 return 523 } 524 producers := c.api.ProducersFor(normalizeOffers(offers)) 525 prod, ok := producers[format] 526 if !ok { 527 panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format)) 528 } 529 if err := prod.Produce(rw, data); err != nil { 530 panic(err) // let the recovery middleware deal with this 531 } 532 return 533 } 534 535 if _, code, ok := route.Operation.SuccessResponse(); ok { 536 rw.WriteHeader(code) 537 if code == 204 || r.Method == "HEAD" { 538 return 539 } 540 541 producers := route.Producers 542 prod, ok := producers[format] 543 if !ok { 544 if !ok { 545 prods := c.api.ProducersFor(normalizeOffers([]string{c.api.DefaultProduces()})) 546 pr, ok := prods[c.api.DefaultProduces()] 547 if !ok { 548 panic(errors.New(http.StatusInternalServerError, "can't find a producer for "+format)) 549 } 550 prod = pr 551 } 552 } 553 if err := prod.Produce(rw, data); err != nil { 554 panic(err) // let the recovery middleware deal with this 555 } 556 return 557 } 558 559 c.api.ServeErrorFor(route.Operation.ID)(rw, r, errors.New(http.StatusInternalServerError, "can't produce response")) 560} 561 562// APIHandler returns a handler to serve the API, this includes a swagger spec, router and the contract defined in the swagger spec 563func (c *Context) APIHandler(builder Builder) http.Handler { 564 b := builder 565 if b == nil { 566 b = PassthroughBuilder 567 } 568 569 var title string 570 sp := c.spec.Spec() 571 if sp != nil && sp.Info != nil && sp.Info.Title != "" { 572 title = sp.Info.Title 573 } 574 575 redocOpts := RedocOpts{ 576 BasePath: c.BasePath(), 577 Title: title, 578 } 579 580 return Spec("", c.spec.Raw(), Redoc(redocOpts, c.RoutesHandler(b))) 581} 582 583// RoutesHandler returns a handler to serve the API, just the routes and the contract defined in the swagger spec 584func (c *Context) RoutesHandler(builder Builder) http.Handler { 585 b := builder 586 if b == nil { 587 b = PassthroughBuilder 588 } 589 return NewRouter(c, b(NewOperationExecutor(c))) 590} 591