// Package cors is a net/http handler for CORS-related requests
// as defined by http://www.w3.org/TR/cors/
3//
4// You can configure it by passing an option struct to cors.New:
5//
6//     c := cors.New(cors.Options{
//         AllowedOrigins: []string{"http://foo.com"},
8//         AllowedMethods: []string{"GET", "POST", "DELETE"},
9//         AllowCredentials: true,
10//     })
11//
12// Then insert the handler in the chain:
13//
14//     handler = c.Handler(handler)
15//
16// See Options documentation for more options.
17//
18// The resulting handler is a standard net/http handler.
19package cors
20
21import (
22	"log"
23	"net/http"
24	"os"
25	"strconv"
26	"strings"
27)
28
// Options is a configuration container to set up the CORS middleware. It is
// consumed by New, which normalizes the values into a Cors.
type Options struct {
	// AllowedOrigins is a list of origins a cross-domain request can be executed from.
	// If the special "*" value is present in the list, all origins will be allowed.
	// An origin may contain a wildcard (*) to replace 0 or more characters
	// (e.g. http://*.domain.com). Usage of wildcards implies a small performance penalty.
	// Only one wildcard can be used per origin: New splits each entry at its
	// first "*". Entries are compared case-insensitively (both sides are
	// lower-cased) against the full Origin header value, scheme included.
	// Default: when the list is empty and AllowOriginFunc is nil, all origins
	// are allowed.
	AllowedOrigins []string

	// AllowOriginFunc is a custom function to validate the origin. It receives
	// the request and its Origin header value and returns true if the origin
	// is allowed, false otherwise. When set, it alone decides whether an origin
	// is allowed and AllowedOrigins is not consulted for that decision (a "*"
	// entry in AllowedOrigins still makes responses use
	// "Access-Control-Allow-Origin: *").
	AllowOriginFunc func(r *http.Request, origin string) bool

	// AllowedMethods is a list of methods the client is allowed to use with
	// cross-domain requests. Values are upper-cased by New, and OPTIONS is
	// always accepted. Default value is the simple methods (GET, POST and HEAD).
	AllowedMethods []string

	// AllowedHeaders is a list of non-simple headers the client is allowed to
	// use with cross-domain requests. Names are canonicalized with
	// http.CanonicalHeaderKey.
	// If the special "*" value is present in the list, all headers will be allowed.
	// Default value is ["Origin", "Accept", "Content-Type"]; when the list is
	// non-empty, "Origin" is always appended to it.
	AllowedHeaders []string

	// ExposedHeaders lists the response headers that client code is allowed
	// to read. They are sent, canonicalized, in the
	// Access-Control-Expose-Headers header of actual (non-preflight) responses.
	ExposedHeaders []string

	// AllowCredentials indicates whether the request can include user credentials like
	// cookies, HTTP authentication or client side SSL certificates. When true,
	// responses carry "Access-Control-Allow-Credentials: true".
	AllowCredentials bool

	// MaxAge indicates how long (in seconds) the results of a preflight request
	// can be cached. The Access-Control-Max-Age header is only sent when MaxAge
	// is greater than zero.
	MaxAge int

	// OptionsPassthrough makes the preflight handler pass OPTIONS requests on
	// to the next handler after setting the CORS headers. Turn this on if your
	// application handles OPTIONS itself. When false, preflight requests are
	// answered with 200 OK and the chain stops.
	OptionsPassthrough bool

	// Debug, when true, makes New install a logger that writes to stdout with a
	// "[cors] " prefix, to help diagnose server-side CORS issues.
	Debug bool
}
73
// Logger is the minimal logging contract used for debug output. Any value
// with a Printf method, such as a *log.Logger, satisfies it.
type Logger interface {
	Printf(format string, args ...interface{})
}
78
79// Cors http handler
80type Cors struct {
81	// Debug logger
82	Log Logger
83
84	// Normalized list of plain allowed origins
85	allowedOrigins []string
86
87	// List of allowed origins containing wildcards
88	allowedWOrigins []wildcard
89
90	// Optional origin validator function
91	allowOriginFunc func(r *http.Request, origin string) bool
92
93	// Normalized list of allowed headers
94	allowedHeaders []string
95
96	// Normalized list of allowed methods
97	allowedMethods []string
98
99	// Normalized list of exposed headers
100	exposedHeaders []string
101	maxAge         int
102
103	// Set to true when allowed origins contains a "*"
104	allowedOriginsAll bool
105
106	// Set to true when allowed headers contains a "*"
107	allowedHeadersAll bool
108
109	allowCredentials  bool
110	optionPassthrough bool
111}
112
113// New creates a new Cors handler with the provided options.
114func New(options Options) *Cors {
115	c := &Cors{
116		exposedHeaders:    convert(options.ExposedHeaders, http.CanonicalHeaderKey),
117		allowOriginFunc:   options.AllowOriginFunc,
118		allowCredentials:  options.AllowCredentials,
119		maxAge:            options.MaxAge,
120		optionPassthrough: options.OptionsPassthrough,
121	}
122	if options.Debug && c.Log == nil {
123		c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
124	}
125
126	// Normalize options
127	// Note: for origins and methods matching, the spec requires a case-sensitive matching.
128	// As it may error prone, we chose to ignore the spec here.
129
130	// Allowed Origins
131	if len(options.AllowedOrigins) == 0 {
132		if options.AllowOriginFunc == nil {
133			// Default is all origins
134			c.allowedOriginsAll = true
135		}
136	} else {
137		c.allowedOrigins = []string{}
138		c.allowedWOrigins = []wildcard{}
139		for _, origin := range options.AllowedOrigins {
140			// Normalize
141			origin = strings.ToLower(origin)
142			if origin == "*" {
143				// If "*" is present in the list, turn the whole list into a match all
144				c.allowedOriginsAll = true
145				c.allowedOrigins = nil
146				c.allowedWOrigins = nil
147				break
148			} else if i := strings.IndexByte(origin, '*'); i >= 0 {
149				// Split the origin in two: start and end string without the *
150				w := wildcard{origin[0:i], origin[i+1:]}
151				c.allowedWOrigins = append(c.allowedWOrigins, w)
152			} else {
153				c.allowedOrigins = append(c.allowedOrigins, origin)
154			}
155		}
156	}
157
158	// Allowed Headers
159	if len(options.AllowedHeaders) == 0 {
160		// Use sensible defaults
161		c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"}
162	} else {
163		// Origin is always appended as some browsers will always request for this header at preflight
164		c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)
165		for _, h := range options.AllowedHeaders {
166			if h == "*" {
167				c.allowedHeadersAll = true
168				c.allowedHeaders = nil
169				break
170			}
171		}
172	}
173
174	// Allowed Methods
175	if len(options.AllowedMethods) == 0 {
176		// Default is spec's "simple" methods
177		c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead}
178	} else {
179		c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper)
180	}
181
182	return c
183}
184
185// Handler creates a new Cors handler with passed options.
186func Handler(options Options) func(next http.Handler) http.Handler {
187	c := New(options)
188	return c.Handler
189}
190
191// AllowAll create a new Cors handler with permissive configuration allowing all
192// origins with all standard methods with any header and credentials.
193func AllowAll() *Cors {
194	return New(Options{
195		AllowedOrigins: []string{"*"},
196		AllowedMethods: []string{
197			http.MethodHead,
198			http.MethodGet,
199			http.MethodPost,
200			http.MethodPut,
201			http.MethodPatch,
202			http.MethodDelete,
203		},
204		AllowedHeaders:   []string{"*"},
205		AllowCredentials: false,
206	})
207}
208
209// Handler apply the CORS specification on the request, and add relevant CORS headers
210// as necessary.
211func (c *Cors) Handler(next http.Handler) http.Handler {
212	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
213		if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
214			c.logf("Handler: Preflight request")
215			c.handlePreflight(w, r)
216			// Preflight requests are standalone and should stop the chain as some other
217			// middleware may not handle OPTIONS requests correctly. One typical example
218			// is authentication middleware ; OPTIONS requests won't carry authentication
219			// headers (see #1)
220			if c.optionPassthrough {
221				next.ServeHTTP(w, r)
222			} else {
223				w.WriteHeader(http.StatusOK)
224			}
225		} else {
226			c.logf("Handler: Actual request")
227			c.handleActualRequest(w, r)
228			next.ServeHTTP(w, r)
229		}
230	})
231}
232
233// handlePreflight handles pre-flight CORS requests
234func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
235	headers := w.Header()
236	origin := r.Header.Get("Origin")
237
238	if r.Method != http.MethodOptions {
239		c.logf("Preflight aborted: %s!=OPTIONS", r.Method)
240		return
241	}
242	// Always set Vary headers
243	// see https://github.com/rs/cors/issues/10,
244	//     https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
245	headers.Add("Vary", "Origin")
246	headers.Add("Vary", "Access-Control-Request-Method")
247	headers.Add("Vary", "Access-Control-Request-Headers")
248
249	if origin == "" {
250		c.logf("Preflight aborted: empty origin")
251		return
252	}
253	if !c.isOriginAllowed(r, origin) {
254		c.logf("Preflight aborted: origin '%s' not allowed", origin)
255		return
256	}
257
258	reqMethod := r.Header.Get("Access-Control-Request-Method")
259	if !c.isMethodAllowed(reqMethod) {
260		c.logf("Preflight aborted: method '%s' not allowed", reqMethod)
261		return
262	}
263	reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers"))
264	if !c.areHeadersAllowed(reqHeaders) {
265		c.logf("Preflight aborted: headers '%v' not allowed", reqHeaders)
266		return
267	}
268	if c.allowedOriginsAll {
269		headers.Set("Access-Control-Allow-Origin", "*")
270	} else {
271		headers.Set("Access-Control-Allow-Origin", origin)
272	}
273	// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
274	// by Access-Control-Request-Method (if supported) can be enough
275	headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
276	if len(reqHeaders) > 0 {
277
278		// Spec says: Since the list of headers can be unbounded, simply returning supported headers
279		// from Access-Control-Request-Headers can be enough
280		headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
281	}
282	if c.allowCredentials {
283		headers.Set("Access-Control-Allow-Credentials", "true")
284	}
285	if c.maxAge > 0 {
286		headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge))
287	}
288	c.logf("Preflight response headers: %v", headers)
289}
290
291// handleActualRequest handles simple cross-origin requests, actual request or redirects
292func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
293	headers := w.Header()
294	origin := r.Header.Get("Origin")
295
296	// Always set Vary, see https://github.com/rs/cors/issues/10
297	headers.Add("Vary", "Origin")
298	if origin == "" {
299		c.logf("Actual request no headers added: missing origin")
300		return
301	}
302	if !c.isOriginAllowed(r, origin) {
303		c.logf("Actual request no headers added: origin '%s' not allowed", origin)
304		return
305	}
306
307	// Note that spec does define a way to specifically disallow a simple method like GET or
308	// POST. Access-Control-Allow-Methods is only used for pre-flight requests and the
309	// spec doesn't instruct to check the allowed methods for simple cross-origin requests.
310	// We think it's a nice feature to be able to have control on those methods though.
311	if !c.isMethodAllowed(r.Method) {
312		c.logf("Actual request no headers added: method '%s' not allowed", r.Method)
313
314		return
315	}
316	if c.allowedOriginsAll {
317		headers.Set("Access-Control-Allow-Origin", "*")
318	} else {
319		headers.Set("Access-Control-Allow-Origin", origin)
320	}
321	if len(c.exposedHeaders) > 0 {
322		headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
323	}
324	if c.allowCredentials {
325		headers.Set("Access-Control-Allow-Credentials", "true")
326	}
327	c.logf("Actual response added headers: %v", headers)
328}
329
330// convenience method. checks if a logger is set.
331func (c *Cors) logf(format string, a ...interface{}) {
332	if c.Log != nil {
333		c.Log.Printf(format, a...)
334	}
335}
336
337// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
338// on the endpoint
339func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool {
340	if c.allowOriginFunc != nil {
341		return c.allowOriginFunc(r, origin)
342	}
343	if c.allowedOriginsAll {
344		return true
345	}
346	origin = strings.ToLower(origin)
347	for _, o := range c.allowedOrigins {
348		if o == origin {
349			return true
350		}
351	}
352	for _, w := range c.allowedWOrigins {
353		if w.match(origin) {
354			return true
355		}
356	}
357	return false
358}
359
360// isMethodAllowed checks if a given method can be used as part of a cross-domain request
361// on the endpoint
362func (c *Cors) isMethodAllowed(method string) bool {
363	if len(c.allowedMethods) == 0 {
364		// If no method allowed, always return false, even for preflight request
365		return false
366	}
367	method = strings.ToUpper(method)
368	if method == http.MethodOptions {
369		// Always allow preflight requests
370		return true
371	}
372	for _, m := range c.allowedMethods {
373		if m == method {
374			return true
375		}
376	}
377	return false
378}
379
380// areHeadersAllowed checks if a given list of headers are allowed to used within
381// a cross-domain request.
382func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
383	if c.allowedHeadersAll || len(requestedHeaders) == 0 {
384		return true
385	}
386	for _, header := range requestedHeaders {
387		header = http.CanonicalHeaderKey(header)
388		found := false
389		for _, h := range c.allowedHeaders {
390			if h == header {
391				found = true
392				break
393			}
394		}
395		if !found {
396			return false
397		}
398	}
399	return true
400}
401