1"""
2WSGI Protocol Linter
3====================
4
5This module provides a middleware that performs sanity checks on the
6behavior of the WSGI server and application. It checks that the
7:pep:`3333` WSGI spec is properly implemented. It also warns on some
8common HTTP errors such as non-empty responses for 304 status codes.
9
10.. autoclass:: LintMiddleware
11
12:copyright: 2007 Pallets
13:license: BSD-3-Clause
14"""
15import typing as t
16from types import TracebackType
17from urllib.parse import urlparse
18from warnings import warn
19
20from ..datastructures import Headers
21from ..http import is_entity_header
22from ..wsgi import FileWrapper
23
24if t.TYPE_CHECKING:
25    from _typeshed.wsgi import StartResponse
26    from _typeshed.wsgi import WSGIApplication
27    from _typeshed.wsgi import WSGIEnvironment
28
29
class WSGIWarning(Warning):
    """Category used for warnings about violations of the WSGI protocol.

    Emitted through :func:`warnings.warn` by the lint checks in this module.
    """
32
33
class HTTPWarning(Warning):
    """Category used for warnings about questionable HTTP responses.

    Emitted through :func:`warnings.warn` by the lint checks in this module.
    """
36
37
def check_type(context: str, obj: object, need: t.Type = str) -> None:
    """Emit a :class:`WSGIWarning` unless ``obj`` is exactly of type ``need``.

    The comparison is on type identity, not :func:`isinstance`, so instances
    of subclasses of ``need`` are reported as well.

    :param context: Name of the value being checked, used in the message.
    :param obj: The value to check.
    :param need: The exact type ``obj`` is required to have.
    """
    actual = type(obj)
    if actual is need:
        return
    message = f"{context!r} requires {need.__name__!r}, got {actual.__name__!r}."
    # stacklevel=3 skips this helper and its direct caller.
    warn(message, WSGIWarning, stacklevel=3)
45
46
class InputStream:
    """Wrapper around ``wsgi.input`` that warns about non-portable usage.

    Every method delegates to the wrapped stream. Before delegating, a
    :class:`WSGIWarning` is emitted when the call relies on behavior that
    :pep:`3333` does not guarantee.

    :param stream: The original ``wsgi.input`` stream.
    """

    def __init__(self, stream: t.IO[bytes]) -> None:
        self._stream = stream

    def read(self, *args: t.Any) -> bytes:
        """Read from the stream, warning unless exactly one size is given."""
        if len(args) == 0:
            warn(
                "WSGI does not guarantee an EOF marker on the input stream, thus making"
                " calls to 'wsgi.input.read()' unsafe. Conforming servers may never"
                " return from this call.",
                WSGIWarning,
                stacklevel=2,
            )
        elif len(args) != 1:
            warn(
                "Too many parameters passed to 'wsgi.input.read()'.",
                WSGIWarning,
                stacklevel=2,
            )
        return self._stream.read(*args)

    def readline(self, *args: t.Any) -> bytes:
        """Read one line from the stream.

        Warns when called without arguments or with a size hint.

        :raises TypeError: If more than one argument is passed.
        """
        if len(args) == 0:
            warn(
                "Calls to 'wsgi.input.readline()' without arguments are unsafe. Use"
                " 'wsgi.input.read()' instead.",
                WSGIWarning,
                stacklevel=2,
            )
        elif len(args) == 1:
            warn(
                "'wsgi.input.readline()' was called with a size hint. WSGI does not"
                " support this, although it's available on all major servers.",
                WSGIWarning,
                stacklevel=2,
            )
        else:
            raise TypeError("Too many arguments passed to 'wsgi.input.readline()'.")
        return self._stream.readline(*args)

    def readlines(self, *args: t.Any) -> t.List[bytes]:
        """Read a list of lines from the stream.

        :pep:`3333` lists ``readlines(hint)`` among the required methods of
        ``wsgi.input``. Without this method, the wrapper would turn such a
        call into an :exc:`AttributeError`. Calling it without a hint reads
        until EOF, so it is flagged for the same reason as ``read()``
        without a size.
        """
        if len(args) == 0:
            warn(
                "WSGI does not guarantee an EOF marker on the input stream, thus making"
                " calls to 'wsgi.input.readlines()' without a hint unsafe. Conforming"
                " servers may never return from this call.",
                WSGIWarning,
                stacklevel=2,
            )
        return self._stream.readlines(*args)

    def __iter__(self) -> t.Iterator[bytes]:
        """Iterate the wrapped stream, or warn and yield nothing if it can't."""
        try:
            return iter(self._stream)
        except TypeError:
            warn("'wsgi.input' is not iterable.", WSGIWarning, stacklevel=2)
            return iter(())

    def close(self) -> None:
        """Close the wrapped stream, warning that the application did so."""
        warn("The application closed the input stream!", WSGIWarning, stacklevel=2)
        self._stream.close()
97
98
class ErrorStream:
    """Wrapper around ``wsgi.errors`` that type-checks what is written.

    :param stream: The original ``wsgi.errors`` text stream.
    """

    def __init__(self, stream: t.IO[str]) -> None:
        self._stream = stream

    def write(self, s: str) -> None:
        """Write ``s`` to the stream, warning if it is not exactly a ``str``."""
        # The context names the real environ key, ``wsgi.errors``.
        check_type("wsgi.errors.write()", s, str)
        self._stream.write(s)

    def flush(self) -> None:
        """Flush the wrapped stream."""
        self._stream.flush()

    def writelines(self, seq: t.Iterable[str]) -> None:
        """Write each item of ``seq`` through :meth:`write`, checking each one."""
        for line in seq:
            self.write(line)

    def close(self) -> None:
        """Close the wrapped stream, warning that the application did so."""
        warn("The application closed the error stream!", WSGIWarning, stacklevel=2)
        self._stream.close()
117
118
class GuardedWrite:
    """Wrap the ``write`` callable returned by the server's ``start_response``.

    Each call type-checks its argument, forwards it to the real ``write``,
    and then records the argument's length in the shared ``chunks`` list.

    :param write: The ``write`` callable returned by the server.
    :param chunks: List that receives the length of every written value.
    """

    def __init__(self, write: t.Callable[[bytes], None], chunks: t.List[int]) -> None:
        self._write, self._chunks = write, chunks

    def __call__(self, s: bytes) -> None:
        check_type("write()", s, bytes)
        forward = self._write
        forward(s)
        size = len(s)
        self._chunks.append(size)
128
129
class GuardedIterator:
    """Wrap the application's response iterable and lint it as it is consumed.

    Every item must be exactly ``bytes``; each item's length is appended to
    ``chunks``. On :meth:`close`, the recorded sizes are compared against
    the response status and headers.

    :param iterator: The iterable returned by the application.
    :param headers_set: Container holding ``(status_code, headers)`` once the
        response was started; empty (falsy) before that.
    :param chunks: List of sizes of the data sent so far.
    """

    def __init__(
        self,
        iterator: t.Iterable[bytes],
        headers_set: t.Tuple[int, Headers],
        chunks: t.List[int],
    ) -> None:
        self._iterator = iterator
        self._next = iter(iterator).__next__
        self.closed = False
        self.headers_set = headers_set
        self.chunks = chunks

    def __iter__(self) -> "GuardedIterator":
        return self

    def __next__(self) -> bytes:
        """Return the next item, warning about out-of-order or non-bytes use."""
        if self.closed:
            warn("Iterated over closed 'app_iter'.", WSGIWarning, stacklevel=2)

        rv = self._next()

        if not self.headers_set:
            warn(
                "The application returned before it started the response.",
                WSGIWarning,
                stacklevel=2,
            )

        check_type("application iterator items", rv, bytes)
        self.chunks.append(len(rv))
        return rv

    def close(self) -> None:
        """Close the wrapped iterable and check the finished response."""
        self.closed = True

        if hasattr(self._iterator, "close"):
            self._iterator.close()  # type: ignore

        if self.headers_set:
            status_code, headers = self.headers_set
            bytes_sent = sum(self.chunks)
            content_length = headers.get("content-length", type=int)

            if status_code == 304:
                for key, _value in headers:
                    key = key.lower()
                    if key not in ("expires", "content-location") and is_entity_header(
                        key
                    ):
                        warn(
                            f"Entity header {key!r} found in 304 response.", HTTPWarning
                        )
                if bytes_sent:
                    warn("304 responses must not have a body.", HTTPWarning)
            elif 100 <= status_code < 200 or status_code == 204:
                # Servers must not send Content-Length with 1xx/204 responses,
                # so an absent header (None) is correct; only a non-zero value
                # is reported.
                if content_length is not None and content_length != 0:
                    warn(
                        f"{status_code} responses must have an empty content length.",
                        HTTPWarning,
                    )
                if bytes_sent:
                    warn(f"{status_code} responses must not have a body.", HTTPWarning)
            elif content_length is not None and content_length != bytes_sent:
                warn(
                    "Content-Length and the number of bytes sent to the"
                    " client do not match.",
                    WSGIWarning,
                )

    def __del__(self) -> None:
        # ``closed`` is missing when ``__init__`` failed (e.g. the application
        # returned a non-iterable); there is no iterator to report on then.
        if not getattr(self, "closed", True):
            try:
                warn(
                    "Iterator was garbage collected before it was closed.", WSGIWarning
                )
            except Exception:
                # Best effort: warning machinery may be unusable during shutdown.
                pass
208
209
class LintMiddleware:
    """Warns about common errors in the WSGI and HTTP behavior of the
    server and wrapped application. Some of the issues it checks are:

    -   invalid status codes
    -   non-bytes sent to the WSGI server
    -   strings returned from the WSGI application
    -   non-empty conditional responses
    -   unquoted etags
    -   relative URLs in the Location header
    -   unsafe calls to wsgi.input
    -   unclosed iterators

    Error information is emitted using the :mod:`warnings` module.

    :param app: The WSGI application to wrap.

    .. code-block:: python

        from werkzeug.middleware.lint import LintMiddleware
        app = LintMiddleware(app)
    """

    def __init__(self, app: "WSGIApplication") -> None:
        self.app = app

    def check_environ(self, environ: "WSGIEnvironment") -> None:
        """Warn about an environ that does not follow :pep:`3333`.

        All warnings use ``stacklevel=3`` (this method, :meth:`__call__`,
        then the code that called the middleware).
        """
        if type(environ) is not dict:
            warn(
                "WSGI environment is not a standard Python dict.",
                WSGIWarning,
                stacklevel=3,
            )
        for key in (
            "REQUEST_METHOD",
            "SERVER_NAME",
            "SERVER_PORT",
            "wsgi.version",
            "wsgi.input",
            "wsgi.errors",
            "wsgi.multithread",
            "wsgi.multiprocess",
            "wsgi.run_once",
        ):
            if key not in environ:
                warn(
                    f"Required environment key {key!r} not found",
                    WSGIWarning,
                    stacklevel=3,
                )
        # ``.get`` so that a missing key, already reported above, does not
        # abort linting with a KeyError.
        if environ.get("wsgi.version") != (1, 0):
            warn("Environ is not a WSGI 1.0 environ.", WSGIWarning, stacklevel=3)

        script_name = environ.get("SCRIPT_NAME", "")
        path_info = environ.get("PATH_INFO", "")

        if script_name and script_name[0] != "/":
            warn(
                f"'SCRIPT_NAME' does not start with a slash: {script_name!r}",
                WSGIWarning,
                stacklevel=3,
            )

        if path_info and path_info[0] != "/":
            warn(
                f"'PATH_INFO' does not start with a slash: {path_info!r}",
                WSGIWarning,
                stacklevel=3,
            )

    def check_start_response(
        self,
        status: str,
        headers: t.List[t.Tuple[str, str]],
        exc_info: t.Optional[
            t.Tuple[t.Type[BaseException], BaseException, TracebackType]
        ],
    ) -> t.Tuple[int, Headers]:
        """Validate the arguments the application passed to ``start_response``.

        :return: The integer status code and the headers as :class:`Headers`.
        :raises ValueError: If the status code part of ``status`` is not
            numeric (raised by ``int()`` after the warning is emitted).
        """
        check_type("status", status, str)
        status_code_str = status.split(None, 1)[0]

        if len(status_code_str) != 3 or not status_code_str.isdigit():
            warn("Status code must be three digits.", WSGIWarning, stacklevel=3)

        if len(status) < 4 or status[3] != " ":
            warn(
                f"Invalid value for status {status!r}. Valid status strings are three"
                " digits, a space and a status explanation.",
                WSGIWarning,
                stacklevel=3,
            )

        status_code = int(status_code_str)

        if status_code < 100:
            warn("Status code < 100 detected.", WSGIWarning, stacklevel=3)

        if type(headers) is not list:
            warn("Header list is not a list.", WSGIWarning, stacklevel=3)

        for item in headers:
            if type(item) is not tuple or len(item) != 2:
                warn("Header items must be 2-item tuples.", WSGIWarning, stacklevel=3)
            name, value = item
            if type(name) is not str or type(value) is not str:
                warn(
                    "Header keys and values must be strings.", WSGIWarning, stacklevel=3
                )
            if name.lower() == "status":
                warn(
                    "The status header is not supported due to"
                    " conflicts with the CGI spec.",
                    WSGIWarning,
                    stacklevel=3,
                )

        if exc_info is not None and not isinstance(exc_info, tuple):
            warn("Invalid value for exc_info.", WSGIWarning, stacklevel=3)

        headers = Headers(headers)
        self.check_headers(headers)

        return status_code, headers

    def check_headers(self, headers: Headers) -> None:
        """Warn about malformed ``ETag`` and non-absolute ``Location`` headers."""
        etag = headers.get("etag")

        if etag is not None:
            if etag.startswith(("W/", "w/")):
                if etag.startswith("w/"):
                    warn(
                        "Weak etag indicator should be upper case.",
                        HTTPWarning,
                        stacklevel=4,
                    )

                etag = etag[2:]

            # Require an opening AND a closing quote; a lone '"' would
            # otherwise satisfy both ends of the comparison.
            if len(etag) < 2 or not (etag[0] == etag[-1] == '"'):
                warn("Unquoted etag emitted.", HTTPWarning, stacklevel=4)

        location = headers.get("location")

        if location is not None:
            if not urlparse(location).netloc:
                warn(
                    "Absolute URLs required for location header.",
                    HTTPWarning,
                    stacklevel=4,
                )

    def check_iterator(self, app_iter: t.Iterable[bytes]) -> None:
        """Warn if the application returned a bare ``bytes`` object."""
        if isinstance(app_iter, bytes):
            warn(
                "The application returned a bytestring. The response will send one"
                " character at a time to the client, which will kill performance."
                " Return a list or iterable instead.",
                WSGIWarning,
                stacklevel=3,
            )

    def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Iterable[bytes]:
        """Run the wrapped application with linting wrappers installed.

        The environ streams, ``start_response``, the ``write`` callable it
        returns, and the response iterable are all wrapped so that misuse
        is reported while the request is processed.
        """
        if len(args) != 2:
            warn("A WSGI app takes two arguments.", WSGIWarning, stacklevel=2)

        if kwargs:
            warn(
                "A WSGI app does not take keyword arguments.", WSGIWarning, stacklevel=2
            )

        environ: "WSGIEnvironment" = args[0]
        start_response: "StartResponse" = args[1]

        self.check_environ(environ)
        environ["wsgi.input"] = InputStream(environ["wsgi.input"])
        environ["wsgi.errors"] = ErrorStream(environ["wsgi.errors"])

        # Hook our own file wrapper in so that applications will always
        # iterate to the end and we can check the content length.
        environ["wsgi.file_wrapper"] = FileWrapper

        # Shared with GuardedIterator: filled with (status_code, headers) by
        # checking_start_response; chunk sizes are recorded by both
        # GuardedWrite and GuardedIterator.
        headers_set: t.List[t.Any] = []
        chunks: t.List[int] = []

        def checking_start_response(
            *args: t.Any, **kwargs: t.Any
        ) -> t.Callable[[bytes], None]:
            if len(args) not in {2, 3}:
                warn(
                    f"Invalid number of arguments: {len(args)}, expected 2 or 3.",
                    WSGIWarning,
                    stacklevel=2,
                )

            if kwargs:
                warn(
                    "'start_response' does not take keyword arguments.",
                    WSGIWarning,
                    stacklevel=2,
                )

            status: str = args[0]
            headers: t.List[t.Tuple[str, str]] = args[1]
            exc_info: t.Optional[
                t.Tuple[t.Type[BaseException], BaseException, TracebackType]
            ] = (args[2] if len(args) == 3 else None)

            headers_set[:] = self.check_start_response(status, headers, exc_info)
            return GuardedWrite(start_response(status, headers, exc_info), chunks)

        app_iter = self.app(environ, t.cast("StartResponse", checking_start_response))
        self.check_iterator(app_iter)
        return GuardedIterator(
            app_iter, t.cast(t.Tuple[int, Headers], headers_set), chunks
        )
421