1import pytest
2
3import flask
4from flask.sessions import SecureCookieSessionInterface
5from flask.sessions import SessionInterface
6
7try:
8    from greenlet import greenlet
9except ImportError:
10    greenlet = None
11
12
def test_teardown_on_pop(app):
    """Popping a manually pushed request context runs the teardown_request
    callback once, passing ``None`` as the exception."""
    buffer = []

    @app.teardown_request
    def end_of_request(exception):
        buffer.append(exception)

    ctx = app.test_request_context()
    ctx.push()
    try:
        # Teardown must not fire while the context is still active.
        assert buffer == []
    finally:
        # Pop even if the assertion fails so the context cannot leak into
        # later tests.
        ctx.pop()
    assert buffer == [None]
25
26
def test_teardown_with_previous_exception(app):
    """An exception raised and handled *before* the request context exists
    is not reported to teardown_request callbacks."""
    seen = []

    @app.teardown_request
    def record_teardown(exception):
        seen.append(exception)

    # Leave a handled exception behind before entering the context.
    try:
        raise Exception("dummy")
    except Exception:
        pass

    with app.test_request_context():
        assert not seen
    assert seen == [None]
42
43
def test_teardown_with_handled_exception(app):
    """An exception caught *inside* the request context is not reported to
    teardown_request callbacks."""
    recorded = []

    @app.teardown_request
    def record_teardown(exception):
        recorded.append(exception)

    with app.test_request_context():
        assert not recorded
        try:
            raise Exception("dummy")
        except Exception:
            pass  # deliberately swallowed; teardown should receive None
    assert recorded == [None]
58
59
def test_proper_test_request_context(app):
    """``test_request_context`` honours ``SERVER_NAME`` when building external
    URLs (including subdomains) and accepts mismatching or matching hosts."""
    import warnings

    app.config.update(SERVER_NAME="localhost.localdomain:5000")

    @app.route("/")
    def index():
        return None

    @app.route("/", subdomain="foo")
    def sub():
        return None

    with app.test_request_context("/"):
        assert (
            flask.url_for("index", _external=True)
            == "http://localhost.localdomain:5000/"
        )

    with app.test_request_context("/"):
        assert (
            flask.url_for("sub", _external=True)
            == "http://foo.localhost.localdomain:5000/"
        )

    # Suppress Werkzeug's (0.15+) warning about the Host header not matching
    # SERVER_NAME. ``pytest.warns(None)`` is deprecated in pytest 7 and an
    # error in pytest 8, so filter just that warning instead.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "Current server name", UserWarning)
        with app.test_request_context(
            "/", environ_overrides={"HTTP_HOST": "localhost"}
        ):
            pass

    app.config.update(SERVER_NAME="localhost")
    with app.test_request_context("/", environ_overrides={"SERVER_NAME": "localhost"}):
        pass

    app.config.update(SERVER_NAME="localhost:80")
    with app.test_request_context(
        "/", environ_overrides={"SERVER_NAME": "localhost:80"}
    ):
        pass
99
100
def test_context_binding(app):
    """View functions can be called directly inside a test request context,
    and no request context stays bound once the ``with`` blocks exit."""

    @app.route("/")
    def index():
        return f"Hello {flask.request.args['name']}!"

    @app.route("/meh")
    def meh():
        return flask.request.url

    with app.test_request_context("/?name=World"):
        assert index() == "Hello World!"
    with app.test_request_context("/meh"):
        assert meh() == "http://localhost/meh"
    # Use the public API rather than the private, deprecated
    # ``flask._request_ctx_stack``.
    assert not flask.has_request_context()
115
116
def test_context_test(app):
    """``flask.request`` and ``flask.has_request_context`` reflect whether a
    request context is currently pushed."""
    assert not flask.has_request_context()
    assert not flask.request

    request_ctx = app.test_request_context()
    request_ctx.push()
    try:
        assert flask.has_request_context()
        assert flask.request
    finally:
        request_ctx.pop()
127
128
def test_manual_context_binding(app):
    """A manually pushed context makes ``flask.request`` usable; after it is
    popped, calling the view raises ``RuntimeError``."""

    @app.route("/")
    def index():
        return f"Hello {flask.request.args['name']}!"

    ctx = app.test_request_context("/?name=World")
    ctx.push()
    try:
        assert index() == "Hello World!"
    finally:
        # Always pop so a failed assertion cannot leak the context.
        ctx.pop()
    with pytest.raises(RuntimeError):
        index()
140
141
@pytest.mark.skipif(greenlet is None, reason="greenlet not installed")
class TestGreenletContextCopying:
    """A request context captured during a request can be used later from a
    greenlet, after the original request has finished."""

    def test_greenlet_context_copying(self, app, client):
        """A manually copied request context can be re-entered with ``with``."""
        pending = []

        @app.route("/")
        def index():
            flask.session["fizz"] = "buzz"
            copied_ctx = flask._request_ctx_stack.top.copy()

            def worker():
                # Nothing is bound until the copied context is entered.
                assert not flask.request
                assert not flask.current_app
                with copied_ctx:
                    assert flask.request
                    assert flask.current_app == app
                    assert flask.request.path == "/"
                    assert flask.request.args["foo"] == "bar"
                    assert flask.session.get("fizz") == "buzz"
                # Leaving the block unbinds the request again.
                assert not flask.request
                return 42

            pending.append(greenlet(worker))
            return "Hello World!"

        response = client.get("/?foo=bar")
        assert response.data == b"Hello World!"

        assert pending[0].run() == 42

    def test_greenlet_context_copying_api(self, app, client):
        """``flask.copy_current_request_context`` binds the request context
        around the decorated function."""
        pending = []

        @app.route("/")
        def index():
            flask.session["fizz"] = "buzz"

            @flask.copy_current_request_context
            def worker():
                assert flask.request
                assert flask.current_app == app
                assert flask.request.path == "/"
                assert flask.request.args["foo"] == "bar"
                assert flask.session.get("fizz") == "buzz"
                return 42

            pending.append(greenlet(worker))
            return "Hello World!"

        response = client.get("/?foo=bar")
        assert response.data == b"Hello World!"

        assert pending[0].run() == 42
197
198
def test_session_error_pops_context():
    """If ``open_session`` raises, the request fails with 500, the view is
    never dispatched, and no request/app context is left bound."""

    class SessionError(Exception):
        pass

    class FailingSessionInterface(SessionInterface):
        def open_session(self, app, request):
            raise SessionError()

    class CustomFlask(flask.Flask):
        session_interface = FailingSessionInterface()

    app = CustomFlask(__name__)
    view_calls = []

    @app.route("/")
    def index():
        # Shouldn't get here. The original merely constructed
        # ``AssertionError()`` without raising it. Record the call so the test
        # can assert on it: either way the response would be a 500.
        view_calls.append(True)
        raise AssertionError("view must not run when open_session fails")

    response = app.test_client().get("/")
    assert response.status_code == 500
    assert not view_calls
    assert not flask.request
    assert not flask.current_app
221
222
def test_session_dynamic_cookie_name():
    """A session interface may choose the cookie name per request; separate
    cookies then keep separate session values."""

    # This session interface will use a cookie with a different name if the
    # requested url ends with the string "dynamic_cookie"
    class PathAwareSessionInterface(SecureCookieSessionInterface):
        def get_cookie_name(self, app):
            if flask.request.url.endswith("dynamic_cookie"):
                return "dynamic_cookie_name"
            else:
                return super().get_cookie_name(app)

    class CustomFlask(flask.Flask):
        session_interface = PathAwareSessionInterface()

    app = CustomFlask(__name__)
    app.secret_key = "secret_key"

    # View names avoid shadowing the ``set`` builtin (and ``get``); only the
    # URLs are used below.
    @app.route("/set", methods=["POST"])
    def set_value():
        flask.session["value"] = flask.request.form["value"]
        return "value set"

    @app.route("/get")
    def get_value():
        v = flask.session.get("value", "None")
        return v

    @app.route("/set_dynamic_cookie", methods=["POST"])
    def set_dynamic_cookie():
        flask.session["value"] = flask.request.form["value"]
        return "value set"

    @app.route("/get_dynamic_cookie")
    def get_dynamic_cookie():
        v = flask.session.get("value", "None")
        return v

    test_client = app.test_client()

    # first set the cookie in both /set urls but each with a different value
    assert test_client.post("/set", data={"value": "42"}).data == b"value set"
    assert (
        test_client.post("/set_dynamic_cookie", data={"value": "616"}).data
        == b"value set"
    )

    # now check that the relevant values come back - meaning that different
    # cookies are being used for the urls that end with "dynamic_cookie"
    assert test_client.get("/get").data == b"42"
    assert test_client.get("/get_dynamic_cookie").data == b"616"
273
274
def test_bad_environ_raises_bad_request():
    """A Host header containing a non-printable character yields a 400."""
    from flask.testing import EnvironBuilder

    app = flask.Flask(__name__)
    environ = EnvironBuilder(app).get_environ()

    # The invalid character in the Host header is what this test exercises.
    environ["HTTP_HOST"] = "\x8a"

    with app.request_context(environ):
        response = app.full_dispatch_request()
    assert response.status_code == 400
289
290
def test_environ_for_valid_idna_completes():
    """A Unicode Host header made only of IDNA-compatible characters is
    served normally."""
    from flask.testing import EnvironBuilder

    app = flask.Flask(__name__)

    @app.route("/")
    def index():
        return "Hello World!"

    environ = EnvironBuilder(app).get_environ()
    # Every character here is IDNA-compatible.
    environ["HTTP_HOST"] = "ąśźäüжŠßя.com"

    with app.request_context(environ):
        response = app.full_dispatch_request()

    assert response.status_code == 200
310
311
def test_normal_environ_completes():
    """An ASCII ``xn--`` (IDNA-encoded) Host header is served normally."""
    app = flask.Flask(__name__)

    @app.route("/")
    def index():
        return "Hello World!"

    client = app.test_client()
    rv = client.get("/", headers={"host": "xn--on-0ia.com"})
    assert rv.status_code == 200
321