"""
Basic HTTP Proxy
================

.. autoclass:: ProxyMiddleware

:copyright: 2007 Pallets
:license: BSD-3-Clause
"""
import typing as t
from http import client

from ..datastructures import EnvironHeaders
from ..http import is_hop_by_hop_header
from ..urls import url_parse
from ..urls import url_quote
from ..wsgi import ClosingIterator
from ..wsgi import get_input_stream

if t.TYPE_CHECKING:
    from _typeshed.wsgi import StartResponse
    from _typeshed.wsgi import WSGIApplication
    from _typeshed.wsgi import WSGIEnvironment


class ProxyMiddleware:
    """Proxy requests under a path to an external server, routing other
    requests to the app.

    This middleware can only proxy HTTP requests, as HTTP is the only
    protocol handled by the WSGI server. Other protocols, such as
    WebSocket requests, cannot be proxied at this layer. This should
    only be used for development, in production a real proxy server
    should be used.

    The middleware takes a dict mapping a path prefix to a dict
    describing the host to be proxied to::

        app = ProxyMiddleware(app, {
            "/static/": {
                "target": "http://127.0.0.1:5001/",
                "remove_prefix": True,
            }
        })

    Each host has the following options:

    ``target``:
        The target URL to dispatch to. This is required.
    ``remove_prefix``:
        Whether to remove the prefix from the URL before dispatching it
        to the target. The default is ``False``.
    ``host``:
        ``"<auto>"`` (default):
            The host header is automatically rewritten to the URL of the
            target.
        ``None``:
            The host header is unmodified from the client request. If
            the client sent no host header, the target's host is used.
        Any other value:
            The host header is overwritten with the value.
    ``headers``:
        A dictionary of headers to be sent with the request to the
        target. The default is ``{}``.
    ``ssl_context``:
        A :class:`ssl.SSLContext` defining how to verify requests if the
        target is HTTPS. The default is ``None``.

    In the example above, everything under ``"/static/"`` is proxied to
    the server on port 5001. The host header is rewritten to the target,
    and the ``"/static/"`` prefix is removed from the URLs.

    :param app: The WSGI application to wrap.
    :param targets: Proxy target configurations. See description above.
    :param chunk_size: Size of chunks to read from input stream and
        write to target.
    :param timeout: Seconds before an operation to a target fails.

    .. versionadded:: 0.14
    """

    def __init__(
        self,
        app: "WSGIApplication",
        targets: t.Mapping[str, t.Dict[str, t.Any]],
        chunk_size: int = 2 << 13,
        timeout: int = 10,
    ) -> None:
        def _set_defaults(opts: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
            # Fill in missing options (in place on the given dict).
            opts.setdefault("remove_prefix", False)
            opts.setdefault("host", "<auto>")
            opts.setdefault("headers", {})
            opts.setdefault("ssl_context", None)
            return opts

        self.app = app
        # Normalize every prefix to the form "/prefix/" so matching in
        # ``__call__`` is a plain ``startswith`` check.
        self.targets = {
            f"/{k.strip('/')}/": _set_defaults(v) for k, v in targets.items()
        }
        self.chunk_size = chunk_size
        self.timeout = timeout

    def proxy_to(
        self, opts: t.Dict[str, t.Any], path: str, prefix: str
    ) -> "WSGIApplication":
        """Build a WSGI application that forwards the request to the
        target described by ``opts``.

        :param opts: A target configuration with defaults applied.
        :param path: The request's ``PATH_INFO``.
        :param prefix: The normalized ``"/prefix/"`` that matched ``path``.
        :return: A WSGI application that performs the upstream request
            and streams the response back. Connection errors produce a
            :exc:`~werkzeug.exceptions.BadGateway` response.
        """
        target = url_parse(opts["target"])
        host = t.cast(str, target.ascii_host)

        def application(
            environ: "WSGIEnvironment", start_response: "StartResponse"
        ) -> t.Iterable[bytes]:
            # Drop hop-by-hop headers, plus the headers recomputed below.
            headers = [
                (k, v)
                for k, v in EnvironHeaders(environ).items()
                if not is_hop_by_hop_header(k)
                and k.lower() not in ("content-length", "host")
            ]
            # One upstream connection per request.
            headers.append(("Connection", "close"))

            if opts["host"] == "<auto>":
                headers.append(("Host", host))
            elif opts["host"] is None:
                # Keep the client's Host header. HTTP_HOST may be absent
                # (e.g. HTTP/1.0 clients); the upstream request is
                # HTTP/1.1 and needs a Host, so fall back to the target.
                headers.append(("Host", environ.get("HTTP_HOST") or host))
            else:
                headers.append(("Host", opts["host"]))

            headers.extend(opts["headers"].items())
            remote_path = path

            if opts["remove_prefix"]:
                remote_path = remote_path[len(prefix) :].lstrip("/")
                remote_path = f"{target.path.rstrip('/')}/{remote_path}"

            content_length = environ.get("CONTENT_LENGTH")
            chunked = False

            if content_length not in ("", None):
                headers.append(("Content-Length", content_length))  # type: ignore
            elif content_length is not None:
                # Empty CONTENT_LENGTH: length unknown, stream it chunked.
                headers.append(("Transfer-Encoding", "chunked"))
                chunked = True

            con: t.Optional[client.HTTPConnection] = None

            try:
                if target.scheme == "http":
                    con = client.HTTPConnection(
                        host, target.port or 80, timeout=self.timeout
                    )
                elif target.scheme == "https":
                    con = client.HTTPSConnection(
                        host,
                        target.port or 443,
                        timeout=self.timeout,
                        context=opts["ssl_context"],
                    )
                else:
                    raise RuntimeError(
                        "Target scheme must be 'http' or 'https', got"
                        f" {target.scheme!r}."
                    )

                con.connect()
                remote_url = url_quote(remote_path)
                # PEP 3333: QUERY_STRING may be empty or absent.
                querystring = environ.get("QUERY_STRING")

                if querystring:
                    remote_url = f"{remote_url}?{querystring}"

                con.putrequest(environ["REQUEST_METHOD"], remote_url, skip_host=True)

                for k, v in headers:
                    # A user-configured "headers" entry may set Connection;
                    # always force "close".
                    if k.lower() == "connection":
                        v = "close"

                    con.putheader(k, v)

                con.endheaders()
                stream = get_input_stream(environ)

                while True:
                    data = stream.read(self.chunk_size)

                    if not data:
                        break

                    if chunked:
                        con.send(b"%x\r\n%s\r\n" % (len(data), data))
                    else:
                        con.send(data)

                if chunked:
                    # The zero-length last chunk ends the body; without it
                    # the target keeps waiting for more data.
                    con.send(b"0\r\n\r\n")

                resp = con.getresponse()
            except OSError:
                from ..exceptions import BadGateway

                if con is not None:
                    con.close()

                return BadGateway()(environ, start_response)

            start_response(
                f"{resp.status} {resp.reason}",
                [
                    (k.title(), v)
                    for k, v in resp.getheaders()
                    if not is_hop_by_hop_header(k)
                ],
            )

            def read() -> t.Iterator[bytes]:
                while True:
                    try:
                        data = resp.read(self.chunk_size)
                    except OSError:
                        break

                    if not data:
                        break

                    yield data

            # Release the response and connection when the server closes
            # the iterable, even if the body was not fully consumed.
            return ClosingIterator(read(), [resp.close, con.close])

        return application

    def __call__(
        self, environ: "WSGIEnvironment", start_response: "StartResponse"
    ) -> t.Iterable[bytes]:
        """Dispatch to the first configured target whose prefix matches
        ``PATH_INFO``, otherwise to the wrapped application.
        """
        path = environ["PATH_INFO"]
        app = self.app

        for prefix, opts in self.targets.items():
            if path.startswith(prefix):
                app = self.proxy_to(opts, path, prefix)
                break

        return app(environ, start_response)