1import collections
2import re
3from traceback import format_exception
4from unittest import mock
5
6import pytest
7
8from aiohttp import helpers, signals, web
9from aiohttp.test_utils import make_mocked_request
10
11
@pytest.fixture
def buf():
    """Provide a fresh, empty ``bytearray`` for each test."""
    empty = bytearray(b"")
    return empty
15
16
@pytest.fixture
def http_request(buf):
    """Return a mocked ``GET /`` request whose writer records output into *buf*.

    The body-writing methods of the mocked writer (``buffer_data``, ``write``
    and ``write_eof``) append their data to *buf*. ``write_headers``
    serialises the status line and each ``name: value`` header pair,
    CRLF-terminated, followed by a blank line. Tests can then inspect the raw
    response text.
    """

    def _capture(data=b""):
        # Record the chunk, then return whatever helpers.noop() produces as
        # the writer call's result.
        buf.extend(data)
        return helpers.noop()

    async def _capture_headers(status_line, headers):
        header_lines = [status_line]
        header_lines.extend(name + ": " + value for name, value in headers.items())
        # CRLF after every line, then an extra CRLF to close the header block.
        raw_text = "\r\n".join(header_lines) + "\r\n"
        buf.extend(raw_text.encode("utf-8") + b"\r\n")

    writer = mock.Mock()
    writer.drain.return_value = ()
    for attr_name in ("buffer_data", "write", "write_eof"):
        getattr(writer, attr_name).side_effect = _capture
    writer.write_headers.side_effect = _capture_headers

    app = mock.Mock()
    app._debug = False
    app.on_response_prepare = signals.Signal(app)
    # Frozen so the signal can be sent while the response is prepared.
    app.on_response_prepare.freeze()

    return make_mocked_request("GET", "/", app=app, writer=writer)
48
49
def test_all_http_exceptions_exported() -> None:
    """Every public ``HTTPException`` subclass on ``web`` is listed in ``__all__``."""
    exported = web.__all__
    assert "HTTPException" in exported

    public_names = (name for name in dir(web) if not name.startswith("_"))
    for name in public_names:
        candidate = getattr(web, name)
        if not isinstance(candidate, type):
            continue
        if issubclass(candidate, web.HTTPException):
            assert name in exported
58
59
async def test_HTTPOk(buf, http_request) -> None:
    """A bare ``HTTPOk`` renders a 200 response with the default ``200: OK`` body."""
    expected_pattern = (
        "HTTP/1.1 200 OK\r\n"
        "Content-Type: text/plain; charset=utf-8\r\n"
        "Content-Length: 7\r\n"
        "Date: .+\r\n"
        "Server: .+\r\n\r\n"
        "200: OK"
    )

    response = web.HTTPOk()
    await response.prepare(http_request)
    await response.write_eof()

    rendered = buf.decode("utf8")
    assert re.match(expected_pattern, rendered)
76
77
def test_terminal_classes_has_status_code() -> None:
    """Leaf ``HTTPException`` classes each define a distinct, non-None status code."""
    exception_classes = {
        obj
        for obj in (getattr(web, name) for name in dir(web))
        if isinstance(obj, type) and issubclass(obj, web.HTTPException)
    }
    # A class is terminal when no other exported exception class lists it as
    # a direct base.
    terminals = {
        cls
        for cls in exception_classes
        if not any(cls in other.__bases__ for other in exception_classes)
    }

    assert all(cls.status_code is not None for cls in terminals)

    code_counts = collections.Counter(cls.status_code for cls in terminals)
    assert None not in code_counts
    # The most frequent status code must appear exactly once, i.e. no two
    # terminal classes share a code.
    _, top_frequency = code_counts.most_common(1)[0]
    assert top_frequency == 1
96
97
async def test_HTTPFound(buf, http_request) -> None:
    """``HTTPFound`` exposes its target and renders a 302 with a Location header."""
    response = web.HTTPFound(location="/redirect")
    assert response.location == "/redirect"
    assert response.headers["location"] == "/redirect"

    await response.prepare(http_request)
    await response.write_eof()

    expected_pattern = (
        "HTTP/1.1 302 Found\r\n"
        "Content-Type: text/plain; charset=utf-8\r\n"
        "Location: /redirect\r\n"
        "Content-Length: 10\r\n"
        "Date: .+\r\n"
        "Server: .+\r\n\r\n"
        "302: Found"
    )
    assert re.match(expected_pattern, buf.decode("utf8"))
115
116
def test_HTTPFound_empty_location() -> None:
    """``HTTPFound`` raises ``ValueError`` for an empty or ``None`` location."""
    for bad_location in ("", None):
        with pytest.raises(ValueError):
            web.HTTPFound(location=bad_location)
123
124
def test_HTTPFound_location_CRLF() -> None:
    """A trailing CRLF in the location must not reach the Location header."""
    location_header = web.HTTPFound(location="/redirect\r\n").headers["Location"]
    assert "\r\n" not in location_header
128
129
async def test_HTTPMethodNotAllowed(buf, http_request) -> None:
    """405 responses upper-case the method and advertise the allowed methods."""
    response = web.HTTPMethodNotAllowed("get", ["POST", "PUT"])
    assert response.method == "GET"
    assert response.allowed_methods == {"POST", "PUT"}
    assert response.headers["allow"] == "POST,PUT"

    await response.prepare(http_request)
    await response.write_eof()

    expected_pattern = (
        "HTTP/1.1 405 Method Not Allowed\r\n"
        "Content-Type: text/plain; charset=utf-8\r\n"
        "Allow: POST,PUT\r\n"
        "Content-Length: 23\r\n"
        "Date: .+\r\n"
        "Server: .+\r\n\r\n"
        "405: Method Not Allowed"
    )
    assert re.match(expected_pattern, buf.decode("utf8"))
148
149
def test_override_body_with_text() -> None:
    """``text=`` replaces the default body and yields a UTF-8 ``text/plain`` response."""
    response = web.HTTPNotFound(text="Page not found")
    assert response.status == 404
    assert response.body == b"Page not found"
    assert response.text == "Page not found"
    assert response.content_type == "text/plain"
    assert response.charset == "utf-8"
157
158
def test_override_body_with_binary() -> None:
    """A raw ``body=`` is accepted (with a DeprecationWarning) and left charset-less."""
    html = "<html><body>Page not found</body></html>"
    payload = html.encode("utf-8")

    with pytest.warns(DeprecationWarning):
        response = web.HTTPNotFound(body=payload, content_type="text/html")

    assert response.status == 404
    assert response.body == payload
    assert response.text == html
    assert response.content_type == "text/html"
    assert response.charset is None
168
169
def test_default_body() -> None:
    """Without an explicit body, ``HTTPOk`` renders ``b"200: OK"``."""
    assert web.HTTPOk().body == b"200: OK"
173
174
def test_empty_body_204() -> None:
    """``HTTPNoContent`` is constructed without a body."""
    no_content = web.HTTPNoContent()
    assert no_content.body is None
178
179
def test_empty_body_205() -> None:
    """``HTTPResetContent`` (205) is constructed without a body."""
    resp = web.HTTPResetContent()
    # Pin the status so the test cannot silently exercise a different class.
    assert 205 == resp.status
    assert resp.body is None
183
184
def test_empty_body_304() -> None:
    """``HTTPNotModified`` (304) is constructed without a body."""
    resp = web.HTTPNotModified()
    # Pin the status so the test cannot silently exercise a different class.
    assert 304 == resp.status
    assert resp.body is None
188
189
def test_link_header_451(buf) -> None:
    """451 responses keep their link and emit it as a ``blocked-by`` Link header."""
    legal_notice = "http://warning.or.kr/"
    response = web.HTTPUnavailableForLegalReasons(link=legal_notice)

    assert legal_notice == response.link
    expected_header = '<http://warning.or.kr/>; rel="blocked-by"'
    assert expected_header == response.headers["Link"]
195
196
def test_HTTPException_retains_cause() -> None:
    """Raising ``HTTPException`` ``from`` another error keeps the chained cause."""
    with pytest.raises(web.HTTPException) as exc_info:
        try:
            raise Exception("CustomException")
        except Exception as original_error:
            raise web.HTTPException() from original_error

    rendered = "".join(format_exception(exc_info.type, exc_info.value, exc_info.tb))
    # Both the original message and the "direct cause" chaining banner must
    # appear in the formatted traceback.
    for fragment in ("CustomException", "direct cause"):
        assert fragment in rendered
206