1"""
2Basic HTTP Proxy
3================
4
5.. autoclass:: ProxyMiddleware
6
7:copyright: 2007 Pallets
8:license: BSD-3-Clause
9"""
10import typing as t
11from http import client
12
13from ..datastructures import EnvironHeaders
14from ..http import is_hop_by_hop_header
15from ..urls import url_parse
16from ..urls import url_quote
17from ..wsgi import get_input_stream
18
19if t.TYPE_CHECKING:
20    from _typeshed.wsgi import StartResponse
21    from _typeshed.wsgi import WSGIApplication
22    from _typeshed.wsgi import WSGIEnvironment
23
24
class ProxyMiddleware:
    """Proxy requests under a path to an external server, routing other
    requests to the app.

    This middleware can only proxy HTTP requests, as HTTP is the only
    protocol handled by the WSGI server. Other protocols, such as
    WebSocket requests, cannot be proxied at this layer. This should
    only be used for development, in production a real proxy server
    should be used.

    The middleware takes a dict mapping a path prefix to a dict
    describing the host to be proxied to::

        app = ProxyMiddleware(app, {
            "/static/": {
                "target": "http://127.0.0.1:5001/",
            }
        })

    Each host has the following options:

    ``target``:
        The target URL to dispatch to. This is required.
    ``remove_prefix``:
        Whether to remove the prefix from the URL before dispatching it
        to the target. The default is ``False``.
    ``host``:
        ``"<auto>"`` (default):
            The host header is automatically rewritten to the URL of the
            target.
        ``None``:
            The host header is unmodified from the client request.
        Any other value:
            The host header is overwritten with the value.
    ``headers``:
        A dictionary of headers to be sent with the request to the
        target. The default is ``{}``.
    ``ssl_context``:
        A :class:`ssl.SSLContext` defining how to verify requests if the
        target is HTTPS. The default is ``None``.

    In the example above, everything under ``"/static/"`` is proxied to
    the server on port 5001. The host header is rewritten to the target.
    Because ``remove_prefix`` defaults to ``False``, the ``"/static/"``
    prefix is kept in the URLs sent to the target; pass
    ``"remove_prefix": True`` to strip it.

    The option dicts passed in ``targets`` are copied; the caller's
    dicts are not modified when defaults are filled in.

    :param app: The WSGI application to wrap.
    :param targets: Proxy target configurations. See description above.
    :param chunk_size: Size of chunks to read from input stream and
        write to target.
    :param timeout: Seconds before an operation to a target fails.

    .. versionadded:: 0.14
    """

    def __init__(
        self,
        app: "WSGIApplication",
        targets: t.Mapping[str, t.Dict[str, t.Any]],
        chunk_size: int = 2 << 13,
        timeout: int = 10,
    ) -> None:
        def _set_defaults(opts: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
            # Build a new dict instead of calling ``setdefault`` on the
            # argument so the caller's configuration is never mutated.
            # Keys given by the caller (even with a ``None`` value) win
            # over the defaults.
            return {
                "remove_prefix": False,
                "host": "<auto>",
                "headers": {},
                "ssl_context": None,
                **opts,
            }

        def _normalize_prefix(prefix: str) -> str:
            # Prefixes are stored as "/name/". A bare "/" (or "") must
            # stay "/" rather than become "//", which would never match.
            stripped = prefix.strip("/")
            return f"/{stripped}/" if stripped else "/"

        self.app = app
        self.targets = {
            _normalize_prefix(k): _set_defaults(v) for k, v in targets.items()
        }
        self.chunk_size = chunk_size
        self.timeout = timeout

    def proxy_to(
        self, opts: t.Dict[str, t.Any], path: str, prefix: str
    ) -> "WSGIApplication":
        """Build a WSGI application that forwards one request to a target.

        :param opts: Options for the matched target, with defaults
            already filled in (see the class description).
        :param path: The request path (``PATH_INFO``) being proxied.
        :param prefix: The normalized prefix that ``path`` matched.
        :return: A WSGI application that sends the request to the target
            and streams the target's response back. It responds with
            :exc:`~werkzeug.exceptions.BadGateway` if talking to the
            target raises :exc:`OSError`.
        :raises RuntimeError: When the returned application is called and
            the target URL scheme is neither ``http`` nor ``https``.
        """
        target = url_parse(opts["target"])
        host = t.cast(str, target.ascii_host)

        def application(
            environ: "WSGIEnvironment", start_response: "StartResponse"
        ) -> t.Iterable[bytes]:
            # Forward the client's headers except those flagged by
            # is_hop_by_hop_header, plus Content-Length and Host, which
            # are recomputed below.
            headers = [
                (k, v)
                for k, v in EnvironHeaders(environ).items()
                if not is_hop_by_hop_header(k)
                and k.lower() not in ("content-length", "host")
            ]
            headers.append(("Connection", "close"))

            if opts["host"] == "<auto>":
                headers.append(("Host", host))
            elif opts["host"] is None:
                # Pass the client's Host header through unmodified. A
                # client may not have sent one, in which case there is
                # nothing to forward.
                client_host = environ.get("HTTP_HOST")

                if client_host is not None:
                    headers.append(("Host", client_host))
            else:
                headers.append(("Host", opts["host"]))

            headers.extend(opts["headers"].items())
            remote_path = path

            if opts["remove_prefix"]:
                remote_path = remote_path[len(prefix) :].lstrip("/")
                remote_path = f"{target.path.rstrip('/')}/{remote_path}"

            content_length = environ.get("CONTENT_LENGTH")
            chunked = False

            if content_length not in ("", None):
                headers.append(("Content-Length", content_length))  # type: ignore
            elif content_length is not None:
                # CONTENT_LENGTH is present but empty: the body length is
                # unknown, so it is sent with chunked transfer encoding.
                headers.append(("Transfer-Encoding", "chunked"))
                chunked = True

            con: t.Optional[client.HTTPConnection] = None

            try:
                if target.scheme == "http":
                    con = client.HTTPConnection(
                        host, target.port or 80, timeout=self.timeout
                    )
                elif target.scheme == "https":
                    con = client.HTTPSConnection(
                        host,
                        target.port or 443,
                        timeout=self.timeout,
                        context=opts["ssl_context"],
                    )
                else:
                    raise RuntimeError(
                        "Target scheme must be 'http' or 'https', got"
                        f" {target.scheme!r}."
                    )

                con.connect()
                remote_url = url_quote(remote_path)
                querystring = environ["QUERY_STRING"]

                if querystring:
                    remote_url = f"{remote_url}?{querystring}"

                # skip_host: the Host header is set explicitly above.
                con.putrequest(environ["REQUEST_METHOD"], remote_url, skip_host=True)

                for k, v in headers:
                    # Configured extra headers must not re-enable
                    # keep-alive; the connection is used for one request.
                    if k.lower() == "connection":
                        v = "close"

                    con.putheader(k, v)

                con.endheaders()
                stream = get_input_stream(environ)

                while True:
                    data = stream.read(self.chunk_size)

                    if not data:
                        break

                    if chunked:
                        con.send(b"%x\r\n%s\r\n" % (len(data), data))
                    else:
                        con.send(data)

                if chunked:
                    # A chunked body must end with a zero-length chunk,
                    # otherwise the target keeps waiting for more data.
                    con.send(b"0\r\n\r\n")

                resp = con.getresponse()
            except OSError:
                # Release the socket before reporting the failure.
                if con is not None:
                    con.close()

                from ..exceptions import BadGateway

                return BadGateway()(environ, start_response)

            start_response(
                f"{resp.status} {resp.reason}",
                [
                    (k.title(), v)
                    for k, v in resp.getheaders()
                    if not is_hop_by_hop_header(k)
                ],
            )

            def read() -> t.Iterator[bytes]:
                # Stream the target's response body in chunks. A read
                # error ends the body early instead of raising. The
                # response and connection are closed once the generator
                # finishes or is closed by the WSGI server.
                try:
                    while True:
                        try:
                            data = resp.read(self.chunk_size)
                        except OSError:
                            break

                        if not data:
                            break

                        yield data
                finally:
                    resp.close()
                    con.close()  # type: ignore[union-attr]

            return read()

        return application

    def __call__(
        self, environ: "WSGIEnvironment", start_response: "StartResponse"
    ) -> t.Iterable[bytes]:
        """Dispatch the request to the first target whose prefix matches
        ``PATH_INFO``, or to the wrapped application if none matches.
        """
        path = environ["PATH_INFO"]
        app = self.app

        for prefix, opts in self.targets.items():
            if path.startswith(prefix):
                app = self.proxy_to(opts, path, prefix)
                break

        return app(environ, start_response)
231