1import re
2from typing import Any, Optional
3from urllib.parse import ParseResult, urlparse
4
5from django.http import HttpRequest, HttpResponse
6from django.utils.cache import patch_vary_headers
7from django.utils.deprecation import MiddlewareMixin
8
9from corsheaders.conf import conf
10from corsheaders.signals import check_request_enabled
11
# Header names that may be added to any CORS response.
ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin"
ACCESS_CONTROL_ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials"
ACCESS_CONTROL_EXPOSE_HEADERS = "Access-Control-Expose-Headers"

# Header names added only when the request method is OPTIONS (preflight).
ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers"
ACCESS_CONTROL_ALLOW_METHODS = "Access-Control-Allow-Methods"
ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age"
18
19
class CorsPostCsrfMiddleware(MiddlewareMixin):
    """
    Undo the HTTP_REFERER rewrite done by ``CorsMiddleware``.

    ``CorsMiddleware`` stashes the client's referer in
    ``request.META["ORIGINAL_HTTP_REFERER"]``. This middleware puts that value
    back into ``HTTP_REFERER``. As its name indicates, it is intended to be
    installed after Django's CSRF middleware.
    """

    def _https_referer_replace_reverse(self, request: HttpRequest) -> None:
        """
        Restore HTTP_REFERER from ORIGINAL_HTTP_REFERER and drop the stash.

        Does nothing when CORS_REPLACE_HTTPS_REFERER is disabled or when no
        stashed referer is present.
        """
        if not conf.CORS_REPLACE_HTTPS_REFERER:
            return
        meta = request.META
        if "ORIGINAL_HTTP_REFERER" not in meta:
            return
        # pop() both reads the stashed value and removes the temporary key.
        meta["HTTP_REFERER"] = meta.pop("ORIGINAL_HTTP_REFERER")

    def process_request(self, request: HttpRequest) -> None:
        """Restore the original referer at the request phase; never short-circuits."""
        self._https_referer_replace_reverse(request)

    def process_view(
        self,
        request: HttpRequest,
        callback: Any,
        callback_args: Any,
        callback_kwargs: Any,
    ) -> None:
        """Restore the original referer again just before the view runs."""
        self._https_referer_replace_reverse(request)
44
45
class CorsMiddleware(MiddlewareMixin):
    """
    Apply the configured CORS policy to requests.

    For requests where CORS is enabled (see ``is_enabled``) this middleware:

    * answers preflight requests (``OPTIONS`` carrying
      ``Access-Control-Request-Method``) with an empty response;
    * adds ``Access-Control-*`` headers to responses for allowed origins;
    * when ``CORS_REPLACE_HTTPS_REFERER`` is set, rewrites ``HTTP_REFERER`` on
      secure requests from allowed origins so Django's CSRF referer check
      accepts them, keeping the client's value in ``ORIGINAL_HTTP_REFERER``
      (which ``CorsPostCsrfMiddleware`` restores).
    """

    def _is_cors_enabled(self, request: HttpRequest) -> bool:
        """
        Return the CORS-enabled flag cached on the request by
        ``process_request``, falling back to ``is_enabled`` when it is absent
        (e.g. ``process_request`` did not run for this request).
        """
        enabled = getattr(request, "_cors_enabled", None)
        if enabled is None:
            enabled = self.is_enabled(request)
        return enabled

    def _https_referer_replace(self, request: HttpRequest) -> None:
        """
        When https is enabled, Django's CSRF checking includes referer
        checking, which breaks for cross-origin requests.

        If the request is secure, carries an Origin that is allowed (or
        CORS_ALLOW_ALL_ORIGINS is set), and has not been rewritten yet, replace
        HTTP_REFERER with ``https://<HTTP_HOST>/`` and stash the original value
        in ORIGINAL_HTTP_REFERER.

        Requests without HTTP_REFERER or HTTP_HOST, or whose Origin cannot be
        parsed, are left unchanged.
        """
        origin = request.META.get("HTTP_ORIGIN")

        if (
            not request.is_secure()
            or not origin
            or "ORIGINAL_HTTP_REFERER" in request.META
        ):
            return

        try:
            url = urlparse(origin)
        except ValueError:
            # Malformed Origin (e.g. an unbalanced IPv6 bracket): treat it as
            # not allowed instead of letting the exception fail the request,
            # matching how process_response handles the same case.
            return

        if (
            not conf.CORS_ALLOW_ALL_ORIGINS
            and not self.origin_found_in_white_lists(origin, url)
        ):
            return

        try:
            http_referer = request.META["HTTP_REFERER"]
            http_host = "https://%s/" % request.META["HTTP_HOST"]
        except KeyError:
            # Nothing to rewrite without both a referer and a host.
            return

        # The original implementation swaps in a copy of META before
        # mutating it; kept as-is (presumably to avoid mutating a dict that
        # may be shared elsewhere — not verifiable from this file).
        request.META = request.META.copy()
        request.META["ORIGINAL_HTTP_REFERER"] = http_referer
        request.META["HTTP_REFERER"] = http_host

    def process_request(self, request: HttpRequest) -> Optional[HttpResponse]:
        """
        Record whether CORS applies to this request and answer preflights.

        The result of ``is_enabled`` is cached on the request as
        ``_cors_enabled`` for the later hooks. If CORS is enabled and the
        request is a preflight (OPTIONS with Access-Control-Request-Method),
        return an empty body response (200 OK). The view is then not called,
        but this middleware's ``process_response`` still runs on that response
        and adds the CORS headers. Otherwise return None so processing
        continues.
        """
        request._cors_enabled = self.is_enabled(request)
        if not request._cors_enabled:
            return None

        if conf.CORS_REPLACE_HTTPS_REFERER:
            self._https_referer_replace(request)

        if (
            request.method == "OPTIONS"
            and "HTTP_ACCESS_CONTROL_REQUEST_METHOD" in request.META
        ):
            response = HttpResponse()
            response["Content-Length"] = "0"
            return response
        return None

    def process_view(
        self,
        request: HttpRequest,
        callback: Any,
        callback_args: Any,
        callback_kwargs: Any,
    ) -> None:
        """
        Do the referer replacement here as well, just before the view runs.

        ``_https_referer_replace`` returns early when ORIGINAL_HTTP_REFERER is
        already present, so repeating it does not rewrite twice.
        """
        if self._is_cors_enabled(request) and conf.CORS_REPLACE_HTTPS_REFERER:
            self._https_referer_replace(request)
        return None

    def process_response(
        self, request: HttpRequest, response: HttpResponse
    ) -> HttpResponse:
        """
        Add the respective CORS headers to ``response`` and return it.

        ``Vary: Origin`` is added whenever CORS is enabled for the request.
        The remaining headers are added only when the request has a parseable
        Origin that is allowed. An origin is allowed by CORS_ALLOW_ALL_ORIGINS,
        by the configured whitelists, or by a ``check_request_enabled``
        receiver. Preflight-only headers are added for OPTIONS requests.
        """
        if not self._is_cors_enabled(request):
            return response

        patch_vary_headers(response, ["Origin"])

        origin = request.META.get("HTTP_ORIGIN")
        if not origin:
            return response

        try:
            url = urlparse(origin)
        except ValueError:
            return response

        if (
            not conf.CORS_ALLOW_ALL_ORIGINS
            and not self.origin_found_in_white_lists(origin, url)
            and not self.check_signal(request)
        ):
            # Disallowed origin: emit no CORS headers at all, including
            # Access-Control-Allow-Credentials.
            return response

        if conf.CORS_ALLOW_ALL_ORIGINS and not conf.CORS_ALLOW_CREDENTIALS:
            response[ACCESS_CONTROL_ALLOW_ORIGIN] = "*"
        else:
            # The wildcard cannot be combined with credentials, so echo the
            # concrete origin instead.
            response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin

        # Set only after the origin check so credential support is advertised
        # solely to allowed origins.
        if conf.CORS_ALLOW_CREDENTIALS:
            response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"

        if conf.CORS_EXPOSE_HEADERS:
            response[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join(
                conf.CORS_EXPOSE_HEADERS
            )

        if request.method == "OPTIONS":
            response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(conf.CORS_ALLOW_HEADERS)
            response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(conf.CORS_ALLOW_METHODS)
            if conf.CORS_PREFLIGHT_MAX_AGE:
                response[ACCESS_CONTROL_MAX_AGE] = str(conf.CORS_PREFLIGHT_MAX_AGE)

        return response

    def origin_found_in_white_lists(self, origin: str, url: ParseResult) -> bool:
        """
        Return True if ``origin`` is allowed by the configured whitelists.

        It is allowed when any one of these holds:

        * it is the literal ``"null"`` and ``"null"`` is listed in
          CORS_ALLOWED_ORIGINS;
        * its scheme and netloc match an entry of CORS_ALLOWED_ORIGINS;
        * it matches a pattern in CORS_ALLOWED_ORIGIN_REGEXES.

        ``url`` is the already-parsed form of ``origin``.
        """
        return (
            (origin == "null" and origin in conf.CORS_ALLOWED_ORIGINS)
            or self._url_in_whitelist(url)
            or self.regex_domain_match(origin)
        )

    def regex_domain_match(self, origin: str) -> bool:
        """Return True if any CORS_ALLOWED_ORIGIN_REGEXES pattern matches ``origin``."""
        # re.match anchors only at the start of the string; a pattern without
        # a trailing "$" also accepts origins that merely begin with a match.
        return any(
            re.match(domain_pattern, origin)
            for domain_pattern in conf.CORS_ALLOWED_ORIGIN_REGEXES
        )

    def is_enabled(self, request: HttpRequest) -> bool:
        """
        Return True if CORS handling applies to ``request``.

        It applies when ``request.path_info`` matches CORS_URLS_REGEX
        (``re.match``, anchored at the start) or when any
        ``check_request_enabled`` receiver returns a truthy value.
        """
        return bool(
            re.match(conf.CORS_URLS_REGEX, request.path_info)
        ) or self.check_signal(request)

    def check_signal(self, request: HttpRequest) -> bool:
        """Return True if any ``check_request_enabled`` receiver returns a truthy value."""
        signal_responses = check_request_enabled.send(sender=None, request=request)
        # Signal.send returns (receiver, return_value) pairs.
        return any(return_value for _receiver, return_value in signal_responses)

    def _url_in_whitelist(self, url: ParseResult) -> bool:
        """
        Return True if ``url`` has the same scheme and netloc as some entry of
        CORS_ALLOWED_ORIGINS. Paths are not compared.
        """
        origins = [urlparse(o) for o in conf.CORS_ALLOWED_ORIGINS]
        return any(
            origin.scheme == url.scheme and origin.netloc == url.netloc
            for origin in origins
        )
195