// Package cors is a net/http handler that handles CORS-related requests
// as defined by http://www.w3.org/TR/cors/.
//
// You can configure it by passing an option struct to cors.New:
//
//	c := cors.New(cors.Options{
//	    AllowedOrigins:   []string{"foo.com"},
//	    AllowedMethods:   []string{"GET", "POST", "DELETE"},
//	    AllowCredentials: true,
//	})
//
// Then insert the handler in the chain:
//
//	handler = c.Handler(handler)
//
// See the Options documentation for more options.
//
// The resulting handler is a standard net/http handler.
package cors

import (
	"log"
	"net/http"
	"os"
	"strconv"
	"strings"
)

// Options is a configuration container to set up the CORS middleware.
type Options struct {
	// AllowedOrigins is a list of origins a cross-domain request can be executed from.
	// If the special "*" value is present in the list, all origins will be allowed.
	// An origin may contain a wildcard (*) to replace 0 or more characters
	// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty.
	// Only one wildcard can be used per origin.
	// Origins are lower-cased, so matching is case-insensitive.
	// Default value is ["*"] when AllowOriginFunc is not set.
	AllowedOrigins []string

	// AllowOriginFunc is a custom function to validate the origin. It takes the
	// request and the origin as arguments and returns true if the origin is
	// allowed or false otherwise. If this option is set, the content of
	// AllowedOrigins is ignored.
	AllowOriginFunc func(r *http.Request, origin string) bool

	// AllowedMethods is a list of methods the client is allowed to use with
	// cross-domain requests. Methods are upper-cased before matching, and OPTIONS
	// is always accepted. Default value is the simple methods (HEAD, GET and POST).
	AllowedMethods []string

	// AllowedHeaders is a list of non-simple headers the client is allowed to use with
	// cross-domain requests. Names are canonicalized with http.CanonicalHeaderKey.
	// If the special "*" value is present in the list, all headers will be allowed.
	// Default value is ["Origin", "Accept", "Content-Type"]; when a list is
	// provided, "Origin" is always appended to it.
	AllowedHeaders []string

	// ExposedHeaders lists the response headers that are safe to expose to the
	// client; they are sent in Access-Control-Expose-Headers on actual
	// (non-preflight) requests.
	ExposedHeaders []string

	// AllowCredentials indicates whether the request can include user credentials like
	// cookies, HTTP authentication or client side SSL certificates.
	AllowCredentials bool

	// MaxAge indicates how long (in seconds) the results of a preflight request
	// can be cached. Access-Control-Max-Age is only sent when MaxAge > 0.
	MaxAge int

	// OptionsPassthrough instructs preflight to let other potential next handlers
	// process the OPTIONS method. Turn this on if your application handles OPTIONS.
	// When off, a preflight request is answered with 200 OK and the chain stops.
	OptionsPassthrough bool

	// Debug enables logging to stdout to help debug server-side CORS issues.
	Debug bool
}

// Logger is the minimal logging interface used for debug output;
// *log.Logger satisfies it.
type Logger interface {
	Printf(string, ...interface{})
}

// Cors is an http middleware that applies the CORS policy built by New.
type Cors struct {
	// Log receives debug output; nil disables logging.
	// New sets it to a stdout logger when Options.Debug is true.
	Log Logger

	// Lower-cased list of plain (wildcard-free) allowed origins
	allowedOrigins []string

	// Allowed origins containing a "*", split into prefix and suffix
	allowedWOrigins []wildcard

	// Optional origin validator function; takes priority over the origin lists
	allowOriginFunc func(r *http.Request, origin string) bool

	// Canonicalized list of allowed headers (nil when allowedHeadersAll is set)
	allowedHeaders []string

	// Upper-cased list of allowed methods
	allowedMethods []string

	// Canonicalized list of exposed headers
	exposedHeaders []string
	// Preflight cache duration in seconds; only sent when > 0
	maxAge int

	// Set to true when allowed origins contains a "*" (or no origins were
	// configured and no AllowOriginFunc was given)
	allowedOriginsAll bool

	// Set to true when allowed headers contains a "*"
	allowedHeadersAll bool

	allowCredentials  bool
	optionPassthrough bool
}

// New creates a new Cors handler with the provided options.
114func New(options Options) *Cors { 115 c := &Cors{ 116 exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), 117 allowOriginFunc: options.AllowOriginFunc, 118 allowCredentials: options.AllowCredentials, 119 maxAge: options.MaxAge, 120 optionPassthrough: options.OptionsPassthrough, 121 } 122 if options.Debug && c.Log == nil { 123 c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags) 124 } 125 126 // Normalize options 127 // Note: for origins and methods matching, the spec requires a case-sensitive matching. 128 // As it may error prone, we chose to ignore the spec here. 129 130 // Allowed Origins 131 if len(options.AllowedOrigins) == 0 { 132 if options.AllowOriginFunc == nil { 133 // Default is all origins 134 c.allowedOriginsAll = true 135 } 136 } else { 137 c.allowedOrigins = []string{} 138 c.allowedWOrigins = []wildcard{} 139 for _, origin := range options.AllowedOrigins { 140 // Normalize 141 origin = strings.ToLower(origin) 142 if origin == "*" { 143 // If "*" is present in the list, turn the whole list into a match all 144 c.allowedOriginsAll = true 145 c.allowedOrigins = nil 146 c.allowedWOrigins = nil 147 break 148 } else if i := strings.IndexByte(origin, '*'); i >= 0 { 149 // Split the origin in two: start and end string without the * 150 w := wildcard{origin[0:i], origin[i+1:]} 151 c.allowedWOrigins = append(c.allowedWOrigins, w) 152 } else { 153 c.allowedOrigins = append(c.allowedOrigins, origin) 154 } 155 } 156 } 157 158 // Allowed Headers 159 if len(options.AllowedHeaders) == 0 { 160 // Use sensible defaults 161 c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"} 162 } else { 163 // Origin is always appended as some browsers will always request for this header at preflight 164 c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey) 165 for _, h := range options.AllowedHeaders { 166 if h == "*" { 167 c.allowedHeadersAll = true 168 c.allowedHeaders = nil 169 break 170 } 171 } 172 } 173 
174 // Allowed Methods 175 if len(options.AllowedMethods) == 0 { 176 // Default is spec's "simple" methods 177 c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead} 178 } else { 179 c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper) 180 } 181 182 return c 183} 184 185// Handler creates a new Cors handler with passed options. 186func Handler(options Options) func(next http.Handler) http.Handler { 187 c := New(options) 188 return c.Handler 189} 190 191// AllowAll create a new Cors handler with permissive configuration allowing all 192// origins with all standard methods with any header and credentials. 193func AllowAll() *Cors { 194 return New(Options{ 195 AllowedOrigins: []string{"*"}, 196 AllowedMethods: []string{ 197 http.MethodHead, 198 http.MethodGet, 199 http.MethodPost, 200 http.MethodPut, 201 http.MethodPatch, 202 http.MethodDelete, 203 }, 204 AllowedHeaders: []string{"*"}, 205 AllowCredentials: false, 206 }) 207} 208 209// Handler apply the CORS specification on the request, and add relevant CORS headers 210// as necessary. 211func (c *Cors) Handler(next http.Handler) http.Handler { 212 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 213 if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { 214 c.logf("Handler: Preflight request") 215 c.handlePreflight(w, r) 216 // Preflight requests are standalone and should stop the chain as some other 217 // middleware may not handle OPTIONS requests correctly. 
One typical example 218 // is authentication middleware ; OPTIONS requests won't carry authentication 219 // headers (see #1) 220 if c.optionPassthrough { 221 next.ServeHTTP(w, r) 222 } else { 223 w.WriteHeader(http.StatusOK) 224 } 225 } else { 226 c.logf("Handler: Actual request") 227 c.handleActualRequest(w, r) 228 next.ServeHTTP(w, r) 229 } 230 }) 231} 232 233// handlePreflight handles pre-flight CORS requests 234func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { 235 headers := w.Header() 236 origin := r.Header.Get("Origin") 237 238 if r.Method != http.MethodOptions { 239 c.logf("Preflight aborted: %s!=OPTIONS", r.Method) 240 return 241 } 242 // Always set Vary headers 243 // see https://github.com/rs/cors/issues/10, 244 // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 245 headers.Add("Vary", "Origin") 246 headers.Add("Vary", "Access-Control-Request-Method") 247 headers.Add("Vary", "Access-Control-Request-Headers") 248 249 if origin == "" { 250 c.logf("Preflight aborted: empty origin") 251 return 252 } 253 if !c.isOriginAllowed(r, origin) { 254 c.logf("Preflight aborted: origin '%s' not allowed", origin) 255 return 256 } 257 258 reqMethod := r.Header.Get("Access-Control-Request-Method") 259 if !c.isMethodAllowed(reqMethod) { 260 c.logf("Preflight aborted: method '%s' not allowed", reqMethod) 261 return 262 } 263 reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers")) 264 if !c.areHeadersAllowed(reqHeaders) { 265 c.logf("Preflight aborted: headers '%v' not allowed", reqHeaders) 266 return 267 } 268 if c.allowedOriginsAll { 269 headers.Set("Access-Control-Allow-Origin", "*") 270 } else { 271 headers.Set("Access-Control-Allow-Origin", origin) 272 } 273 // Spec says: Since the list of methods can be unbounded, simply returning the method indicated 274 // by Access-Control-Request-Method (if supported) can be enough 275 headers.Set("Access-Control-Allow-Methods", 
strings.ToUpper(reqMethod)) 276 if len(reqHeaders) > 0 { 277 278 // Spec says: Since the list of headers can be unbounded, simply returning supported headers 279 // from Access-Control-Request-Headers can be enough 280 headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) 281 } 282 if c.allowCredentials { 283 headers.Set("Access-Control-Allow-Credentials", "true") 284 } 285 if c.maxAge > 0 { 286 headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge)) 287 } 288 c.logf("Preflight response headers: %v", headers) 289} 290 291// handleActualRequest handles simple cross-origin requests, actual request or redirects 292func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { 293 headers := w.Header() 294 origin := r.Header.Get("Origin") 295 296 // Always set Vary, see https://github.com/rs/cors/issues/10 297 headers.Add("Vary", "Origin") 298 if origin == "" { 299 c.logf("Actual request no headers added: missing origin") 300 return 301 } 302 if !c.isOriginAllowed(r, origin) { 303 c.logf("Actual request no headers added: origin '%s' not allowed", origin) 304 return 305 } 306 307 // Note that spec does define a way to specifically disallow a simple method like GET or 308 // POST. Access-Control-Allow-Methods is only used for pre-flight requests and the 309 // spec doesn't instruct to check the allowed methods for simple cross-origin requests. 310 // We think it's a nice feature to be able to have control on those methods though. 
311 if !c.isMethodAllowed(r.Method) { 312 c.logf("Actual request no headers added: method '%s' not allowed", r.Method) 313 314 return 315 } 316 if c.allowedOriginsAll { 317 headers.Set("Access-Control-Allow-Origin", "*") 318 } else { 319 headers.Set("Access-Control-Allow-Origin", origin) 320 } 321 if len(c.exposedHeaders) > 0 { 322 headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", ")) 323 } 324 if c.allowCredentials { 325 headers.Set("Access-Control-Allow-Credentials", "true") 326 } 327 c.logf("Actual response added headers: %v", headers) 328} 329 330// convenience method. checks if a logger is set. 331func (c *Cors) logf(format string, a ...interface{}) { 332 if c.Log != nil { 333 c.Log.Printf(format, a...) 334 } 335} 336 337// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests 338// on the endpoint 339func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool { 340 if c.allowOriginFunc != nil { 341 return c.allowOriginFunc(r, origin) 342 } 343 if c.allowedOriginsAll { 344 return true 345 } 346 origin = strings.ToLower(origin) 347 for _, o := range c.allowedOrigins { 348 if o == origin { 349 return true 350 } 351 } 352 for _, w := range c.allowedWOrigins { 353 if w.match(origin) { 354 return true 355 } 356 } 357 return false 358} 359 360// isMethodAllowed checks if a given method can be used as part of a cross-domain request 361// on the endpoint 362func (c *Cors) isMethodAllowed(method string) bool { 363 if len(c.allowedMethods) == 0 { 364 // If no method allowed, always return false, even for preflight request 365 return false 366 } 367 method = strings.ToUpper(method) 368 if method == http.MethodOptions { 369 // Always allow preflight requests 370 return true 371 } 372 for _, m := range c.allowedMethods { 373 if m == method { 374 return true 375 } 376 } 377 return false 378} 379 380// areHeadersAllowed checks if a given list of headers are allowed to used within 381// a cross-domain 
request. 382func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool { 383 if c.allowedHeadersAll || len(requestedHeaders) == 0 { 384 return true 385 } 386 for _, header := range requestedHeaders { 387 header = http.CanonicalHeaderKey(header) 388 found := false 389 for _, h := range c.allowedHeaders { 390 if h == header { 391 found = true 392 break 393 } 394 } 395 if !found { 396 return false 397 } 398 } 399 return true 400} 401