1# -*- coding: utf-8 -*-
2"""
3    core
4    ~~~~
5    Core functionality shared between the extension and the decorator.
6
7    :copyright: (c) 2016 by Cory Dolphin.
8    :license: MIT, see LICENSE for more details.
9"""
10import re
11import logging
12try:
13    # on python 3
14    from collections.abc import Iterable
15except ImportError:
16    # on python 2.7 and pypy
17    from collections import Iterable
18from datetime import timedelta
19from six import string_types
20from flask import request, current_app
21from werkzeug.datastructures import Headers, MultiDict
22
23LOG = logging.getLogger(__name__)
24
25# Response Headers
26ACL_ORIGIN = 'Access-Control-Allow-Origin'
27ACL_METHODS = 'Access-Control-Allow-Methods'
28ACL_ALLOW_HEADERS = 'Access-Control-Allow-Headers'
29ACL_EXPOSE_HEADERS = 'Access-Control-Expose-Headers'
30ACL_CREDENTIALS = 'Access-Control-Allow-Credentials'
31ACL_MAX_AGE = 'Access-Control-Max-Age'
32
33# Request Header
34ACL_REQUEST_METHOD = 'Access-Control-Request-Method'
35ACL_REQUEST_HEADERS = 'Access-Control-Request-Headers'
36
37ALL_METHODS = ['GET', 'HEAD', 'POST', 'OPTIONS', 'PUT', 'PATCH', 'DELETE']
38CONFIG_OPTIONS = ['CORS_ORIGINS', 'CORS_METHODS', 'CORS_ALLOW_HEADERS',
39                  'CORS_EXPOSE_HEADERS', 'CORS_SUPPORTS_CREDENTIALS',
40                  'CORS_MAX_AGE', 'CORS_SEND_WILDCARD',
41                  'CORS_AUTOMATIC_OPTIONS', 'CORS_VARY_HEADER',
42                  'CORS_RESOURCES', 'CORS_INTERCEPT_EXCEPTIONS',
43                  'CORS_ALWAYS_SEND']
44# Attribute added to request object by decorator to indicate that CORS
45# was evaluated, in case the decorator and extension are both applied
46# to a view.
47FLASK_CORS_EVALUATED = '_FLASK_CORS_EVALUATED'
48
49# Strange, but this gets the type of a compiled regex, which is otherwise not
50# exposed in a public API.
51RegexObject = type(re.compile(''))
52DEFAULT_OPTIONS = dict(origins='*',
53                       methods=ALL_METHODS,
54                       allow_headers='*',
55                       expose_headers=None,
56                       supports_credentials=False,
57                       max_age=None,
58                       send_wildcard=False,
59                       automatic_options=True,
60                       vary_header=True,
61                       resources=r'/*',
62                       intercept_exceptions=True,
63                       always_send=True)
64
65
66def parse_resources(resources):
67    if isinstance(resources, dict):
68        # To make the API more consistent with the decorator, allow a
69        # resource of '*', which is not actually a valid regexp.
70        resources = [(re_fix(k), v) for k, v in resources.items()]
71
72        # Sort by regex length to provide consistency of matching and
73        # to provide a proxy for specificity of match. E.G. longer
74        # regular expressions are tried first.
75        def pattern_length(pair):
76            maybe_regex, _ = pair
77            return len(get_regexp_pattern(maybe_regex))
78
79        return sorted(resources,
80                      key=pattern_length,
81                      reverse=True)
82
83    elif isinstance(resources, string_types):
84        return [(re_fix(resources), {})]
85
86    elif isinstance(resources, Iterable):
87        return [(re_fix(r), {}) for r in resources]
88
89    # Type of compiled regex is not part of the public API. Test for this
90    # at runtime.
91    elif isinstance(resources,  RegexObject):
92        return [(re_fix(resources), {})]
93
94    else:
95        raise ValueError("Unexpected value for resources argument.")
96
97
98def get_regexp_pattern(regexp):
99    """
100    Helper that returns regexp pattern from given value.
101
102    :param regexp: regular expression to stringify
103    :type regexp: _sre.SRE_Pattern or str
104    :returns: string representation of given regexp pattern
105    :rtype: str
106    """
107    try:
108        return regexp.pattern
109    except AttributeError:
110        return str(regexp)
111
112
113def get_cors_origins(options, request_origin):
114    origins = options.get('origins')
115    wildcard = r'.*' in origins
116
117    # If the Origin header is not present terminate this set of steps.
118    # The request is outside the scope of this specification.-- W3Spec
119    if request_origin:
120        LOG.debug("CORS request received with 'Origin' %s", request_origin)
121
122        # If the allowed origins is an asterisk or 'wildcard', always match
123        if wildcard and options.get('send_wildcard'):
124            LOG.debug("Allowed origins are set to '*'. Sending wildcard CORS header.")
125            return ['*']
126        # If the value of the Origin header is a case-sensitive match
127        # for any of the values in list of origins
128        elif try_match_any(request_origin, origins):
129            LOG.debug("The request's Origin header matches. Sending CORS headers.", )
130            # Add a single Access-Control-Allow-Origin header, with either
131            # the value of the Origin header or the string "*" as value.
132            # -- W3Spec
133            return [request_origin]
134        else:
135            LOG.debug("The request's Origin header does not match any of allowed origins.")
136            return None
137
138
139    elif options.get('always_send'):
140        if wildcard:
141            # If wildcard is in the origins, even if 'send_wildcard' is False,
142            # simply send the wildcard. Unless supports_credentials is True,
143            # since that is forbidded by the spec..
144            # It is the most-likely to be correct thing to do (the only other
145            # option is to return nothing, which  almost certainly not what
146            # the developer wants if the '*' origin was specified.
147            if options.get('supports_credentials'):
148                return None
149            else:
150                return ['*']
151        else:
152            # Return all origins that are not regexes.
153            return sorted([o for o in origins if not probably_regex(o)])
154
155    # Terminate these steps, return the original request untouched.
156    else:
157        LOG.debug("The request did not contain an 'Origin' header. This means the browser or client did not request CORS, ensure the Origin Header is set.")
158        return None
159
160
161def get_allow_headers(options, acl_request_headers):
162    if acl_request_headers:
163        request_headers = [h.strip() for h in acl_request_headers.split(',')]
164
165        # any header that matches in the allow_headers
166        matching_headers = filter(
167            lambda h: try_match_any(h, options.get('allow_headers')),
168            request_headers
169        )
170
171        return ', '.join(sorted(matching_headers))
172
173    return None
174
175
176def get_cors_headers(options, request_headers, request_method):
177    origins_to_set = get_cors_origins(options, request_headers.get('Origin'))
178    headers = MultiDict()
179
180    if not origins_to_set:  # CORS is not enabled for this route
181        return headers
182
183    for origin in origins_to_set:
184        headers.add(ACL_ORIGIN, origin)
185
186    headers[ACL_EXPOSE_HEADERS] = options.get('expose_headers')
187
188    if options.get('supports_credentials'):
189        headers[ACL_CREDENTIALS] = 'true'  # case sensative
190
191    # This is a preflight request
192    # http://www.w3.org/TR/cors/#resource-preflight-requests
193    if request_method == 'OPTIONS':
194        acl_request_method = request_headers.get(ACL_REQUEST_METHOD, '').upper()
195
196        # If there is no Access-Control-Request-Method header or if parsing
197        # failed, do not set any additional headers
198        if acl_request_method and acl_request_method in options.get('methods'):
199
200            # If method is not a case-sensitive match for any of the values in
201            # list of methods do not set any additional headers and terminate
202            # this set of steps.
203            headers[ACL_ALLOW_HEADERS] = get_allow_headers(options, request_headers.get(ACL_REQUEST_HEADERS))
204            headers[ACL_MAX_AGE] = options.get('max_age')
205            headers[ACL_METHODS] = options.get('methods')
206        else:
207            LOG.info("The request's Access-Control-Request-Method header does not match allowed methods. CORS headers will not be applied.")
208
209    # http://www.w3.org/TR/cors/#resource-implementation
210    if options.get('vary_header'):
211        # Only set header if the origin returned will vary dynamically,
212        # i.e. if we are not returning an asterisk, and there are multiple
213        # origins that can be matched.
214        if headers[ACL_ORIGIN] == '*':
215            pass
216        elif (len(options.get('origins')) > 1 or
217              len(origins_to_set) > 1 or
218              any(map(probably_regex, options.get('origins')))):
219            headers.add('Vary', 'Origin')
220
221    return MultiDict((k, v) for k, v in headers.items() if v)
222
223
224def set_cors_headers(resp, options):
225    """
226    Performs the actual evaluation of Flas-CORS options and actually
227    modifies the response object.
228
229    This function is used both in the decorator and the after_request
230    callback
231    """
232
233    # If CORS has already been evaluated via the decorator, skip
234    if hasattr(resp, FLASK_CORS_EVALUATED):
235        LOG.debug('CORS have been already evaluated, skipping')
236        return resp
237
238    # Some libraries, like OAuthlib, set resp.headers to non Multidict
239    # objects (Werkzeug Headers work as well). This is a problem because
240    # headers allow repeated values.
241    if (not isinstance(resp.headers, Headers)
242           and not isinstance(resp.headers, MultiDict)):
243        resp.headers = MultiDict(resp.headers)
244
245    headers_to_set = get_cors_headers(options, request.headers, request.method)
246
247    LOG.debug('Settings CORS headers: %s', str(headers_to_set))
248
249    for k, v in headers_to_set.items():
250        resp.headers.add(k, v)
251
252    return resp
253
254def probably_regex(maybe_regex):
255    if isinstance(maybe_regex, RegexObject):
256        return True
257    else:
258        common_regex_chars = ['*', '\\', ']', '?', '$', '^', '[', ']', '(', ')']
259        # Use common characters used in regular expressions as a proxy
260        # for if this string is in fact a regex.
261        return any((c in maybe_regex for c in common_regex_chars))
262
263def re_fix(reg):
264    """
265        Replace the invalid regex r'*' with the valid, wildcard regex r'/.*' to
266        enable the CORS app extension to have a more user friendly api.
267    """
268    return r'.*' if reg == r'*' else reg
269
270
271def try_match_any(inst, patterns):
272    return any(try_match(inst, pattern) for pattern in patterns)
273
274
275def try_match(request_origin, maybe_regex):
276    """Safely attempts to match a pattern or string to a request origin."""
277    if isinstance(maybe_regex, RegexObject):
278        return re.match(maybe_regex, request_origin)
279    elif probably_regex(maybe_regex):
280        return re.match(maybe_regex, request_origin, flags=re.IGNORECASE)
281    else:
282        try:
283            return request_origin.lower() == maybe_regex.lower()
284        except AttributeError:
285            return request_origin == maybe_regex
286
287
288def get_cors_options(appInstance, *dicts):
289    """
290    Compute CORS options for an application by combining the DEFAULT_OPTIONS,
291    the app's configuration-specified options and any dictionaries passed. The
292    last specified option wins.
293    """
294    options = DEFAULT_OPTIONS.copy()
295    options.update(get_app_kwarg_dict(appInstance))
296    if dicts:
297        for d in dicts:
298            options.update(d)
299
300    return serialize_options(options)
301
302
303def get_app_kwarg_dict(appInstance=None):
304    """Returns the dictionary of CORS specific app configurations."""
305    app = (appInstance or current_app)
306
307    # In order to support blueprints which do not have a config attribute
308    app_config = getattr(app, 'config', {})
309
310    return {
311        k.lower().replace('cors_', ''): app_config.get(k)
312        for k in CONFIG_OPTIONS
313        if app_config.get(k) is not None
314    }
315
316
317def flexible_str(obj):
318    """
319    A more flexible str function which intelligently handles stringifying
320    strings, lists and other iterables. The results are lexographically sorted
321    to ensure generated responses are consistent when iterables such as Set
322    are used.
323    """
324    if obj is None:
325        return None
326    elif(not isinstance(obj, string_types)
327            and isinstance(obj, Iterable)):
328        return ', '.join(str(item) for item in sorted(obj))
329    else:
330        return str(obj)
331
332
333def serialize_option(options_dict, key, upper=False):
334    if key in options_dict:
335        value = flexible_str(options_dict[key])
336        options_dict[key] = value.upper() if upper else value
337
338
339def ensure_iterable(inst):
340    """
341    Wraps scalars or string types as a list, or returns the iterable instance.
342    """
343    if isinstance(inst, string_types):
344        return [inst]
345    elif not isinstance(inst, Iterable):
346        return [inst]
347    else:
348        return inst
349
350def sanitize_regex_param(param):
351    return [re_fix(x) for x in ensure_iterable(param)]
352
353
354def serialize_options(opts):
355    """
356    A helper method to serialize and processes the options dictionary.
357    """
358    options = (opts or {}).copy()
359
360    for key in opts.keys():
361        if key not in DEFAULT_OPTIONS:
362             LOG.warning("Unknown option passed to Flask-CORS: %s", key)
363
364    # Ensure origins is a list of allowed origins with at least one entry.
365    options['origins'] = sanitize_regex_param(options.get('origins'))
366    options['allow_headers'] = sanitize_regex_param(options.get('allow_headers'))
367
368    # This is expressly forbidden by the spec. Raise a value error so people
369    # don't get burned in production.
370    if r'.*' in options['origins'] and options['supports_credentials'] and options['send_wildcard']:
371        raise ValueError("Cannot use supports_credentials in conjunction with"
372                         "an origin string of '*'. See: "
373                         "http://www.w3.org/TR/cors/#resource-requests")
374
375
376
377    serialize_option(options, 'expose_headers')
378    serialize_option(options, 'methods', upper=True)
379
380    if isinstance(options.get('max_age'), timedelta):
381        options['max_age'] = str(int(options['max_age'].total_seconds()))
382
383    return options
384