1import pytest
2from flask import Flask
3from flask import jsonify
4from flask import request
5
6from flask_jwt_extended import create_access_token
7from flask_jwt_extended import create_refresh_token
8from flask_jwt_extended import jwt_required
9from flask_jwt_extended import JWTManager
10from flask_jwt_extended import set_access_cookies
11from flask_jwt_extended import set_refresh_cookies
12from flask_jwt_extended import unset_access_cookies
13from flask_jwt_extended import unset_jwt_cookies
14from flask_jwt_extended import unset_refresh_cookies
15
16
def _get_cookie_from_response(response, cookie_name):
    """Parse the ``Set-Cookie`` header belonging to ``cookie_name``.

    Args:
        response: Object exposing ``headers.getlist("Set-Cookie")``.
        cookie_name: Exact name of the cookie to look up.

    Returns:
        A dict for the first matching ``Set-Cookie`` header, or ``None`` when
        no header sets a cookie with that name. Every key is stripped and
        lower-cased: the cookie's own name maps to its value, attributes such
        as ``Path`` map to their value, and value-less flags such as
        ``HttpOnly`` map to ``True``.
    """
    for header in response.headers.getlist("Set-Cookie"):
        attributes = header.split(";")
        # Compare the whole cookie name instead of testing for a substring,
        # so "access_foo" can never be satisfied by "access_foo_csrf" (or by
        # text that merely appears inside another cookie's value).
        name, _, _ = attributes[0].partition("=")
        if name.strip() != cookie_name:
            continue
        cookie = {}
        for attr in attributes:
            # Split on the first "=" only: values may contain "=" themselves.
            key, separator, value = attr.partition("=")
            cookie[key.strip().lower()] = value if separator else True
        return cookie
    return None
28
29
@pytest.fixture(scope="function")
def app():
    """Create a Flask app that reads JWTs from cookies.

    The app exposes routes that set access/refresh cookies, routes that
    unset them (all, access only, refresh only) and endpoints guarded by
    ``jwt_required`` in access, refresh and optional mode. Every set/unset
    route forwards the optional ``domain`` query argument to the cookie helper.
    """
    app = Flask(__name__)
    app.config.update(
        JWT_SECRET_KEY="foobarbaz",
        JWT_TOKEN_LOCATION=["cookies"],
    )
    JWTManager(app)

    @app.route("/access_token", methods=["GET"])
    def access_token():
        """Issue an access token for "username" as cookies."""
        response = jsonify(login=True)
        set_access_cookies(
            response,
            create_access_token("username"),
            domain=request.args.get("domain"),
        )
        return response

    @app.route("/refresh_token", methods=["GET"])
    def refresh_token():
        """Issue a refresh token for "username" as cookies."""
        response = jsonify(login=True)
        set_refresh_cookies(
            response,
            create_refresh_token("username"),
            domain=request.args.get("domain"),
        )
        return response

    @app.route("/delete_tokens", methods=["GET"])
    def delete_tokens():
        """Unset every JWT cookie."""
        response = jsonify(logout=True)
        unset_jwt_cookies(response, domain=request.args.get("domain"))
        return response

    @app.route("/delete_access_tokens", methods=["GET"])
    def delete_access_tokens():
        """Unset only the access cookies."""
        response = jsonify(access_revoked=True)
        unset_access_cookies(response, domain=request.args.get("domain"))
        return response

    @app.route("/delete_refresh_tokens", methods=["GET"])
    def delete_refresh_tokens():
        """Unset only the refresh cookies."""
        response = jsonify(refresh_revoked=True)
        unset_refresh_cookies(response, domain=request.args.get("domain"))
        return response

    @app.route("/protected", methods=["GET"])
    @jwt_required()
    def protected():
        """GET endpoint guarded by ``jwt_required()``."""
        return jsonify(foo="bar")

    @app.route("/post_protected", methods=["POST"])
    @jwt_required()
    def post_protected():
        """POST endpoint guarded by ``jwt_required()``."""
        return jsonify(foo="bar")

    @app.route("/refresh_protected", methods=["GET"])
    @jwt_required(refresh=True)
    def refresh_protected():
        """GET endpoint guarded by ``jwt_required(refresh=True)``."""
        return jsonify(foo="bar")

    @app.route("/post_refresh_protected", methods=["POST"])
    @jwt_required(refresh=True)
    def post_refresh_protected():
        """POST endpoint guarded by ``jwt_required(refresh=True)``."""
        return jsonify(foo="bar")

    @app.route("/optional_post_protected", methods=["POST"])
    @jwt_required(optional=True)
    def optional_post_protected():
        """POST endpoint guarded by ``jwt_required(optional=True)``."""
        return jsonify(foo="bar")

    return app
100
101
@pytest.mark.parametrize(
    "options",
    [
        (
            "/refresh_token",
            "refresh_token_cookie",
            "/refresh_protected",
            "/delete_refresh_tokens",
        ),
        (
            "/access_token",
            "access_token_cookie",
            "/protected",
            "/delete_access_tokens",
        ),
    ],
)
def test_jwt_refresh_required_with_cookies(app, options):
    """Protected endpoints follow the presence of the matching JWT cookie."""
    auth_url, cookie_name, protected_url, delete_url = options
    client = app.test_client()
    missing_cookie = {"msg": f'Missing cookie "{cookie_name}"'}

    # Nothing has been issued yet, so the request must be rejected.
    result = client.get(protected_url)
    assert result.status_code == 401
    assert result.get_json() == missing_cookie

    # Once the cookies have been received the endpoint is reachable.
    client.get(auth_url)
    result = client.get(protected_url)
    assert result.status_code == 200
    assert result.get_json() == {"foo": "bar"}

    # The type-specific delete route removes the cookie again.
    client.get(delete_url)
    result = client.get(protected_url)
    assert result.status_code == 401
    assert result.get_json() == missing_cookie

    # Log back in, then check that the "delete everything" route also works.
    client.get(auth_url)
    result = client.get(protected_url)
    assert result.status_code == 200

    client.get("/delete_tokens")
    result = client.get(protected_url)
    assert result.status_code == 401
    assert result.get_json() == missing_cookie
144
145
@pytest.mark.parametrize(
    "options",
    [
        ("/refresh_token", "csrf_refresh_token", "/post_refresh_protected"),
        ("/access_token", "csrf_access_token", "/post_protected"),
    ],
)
def test_default_access_csrf_protection(app, options):
    """POSTs need the CSRF double submit value in the default header."""
    auth_url, csrf_cookie_name, post_url = options
    client = app.test_client()

    # Log in and read the CSRF double submit value out of its cookie.
    login = client.get(auth_url)
    csrf_cookie = _get_cookie_from_response(login, csrf_cookie_name)
    csrf_token = csrf_cookie[csrf_cookie_name]

    # A POST without the CSRF header is refused.
    result = client.post(post_url)
    assert result.status_code == 401
    assert result.get_json() == {"msg": "Missing CSRF token"}

    # The same POST with the double submit value goes through.
    result = client.post(post_url, headers={"X-CSRF-TOKEN": csrf_token})
    assert result.status_code == 200
    assert result.get_json() == {"foo": "bar"}
171
172
@pytest.mark.parametrize(
    "options",
    [
        ("/refresh_token", "/post_refresh_protected"),
        ("/access_token", "/post_protected"),
    ],
)
def test_non_matching_csrf_token(app, options):
    """A CSRF header that does not match the token is rejected with 401."""
    auth_url, post_url = options
    client = app.test_client()

    # Log in so the client holds the JWT cookies, then send a bogus CSRF value.
    client.get(auth_url)
    result = client.post(post_url, headers={"X-CSRF-TOKEN": "totally_wrong_token"})
    assert result.status_code == 401
    assert result.get_json() == {"msg": "CSRF double submit tokens do not match"}
190
191
@pytest.mark.parametrize(
    "options",
    [
        ("/refresh_token", "/post_refresh_protected"),
        ("/access_token", "/post_protected"),
    ],
)
def test_csrf_disabled(app, options):
    """With JWT_COOKIE_CSRF_PROTECT off, POSTs succeed without a CSRF value."""
    auth_url, post_url = options
    app.config["JWT_COOKIE_CSRF_PROTECT"] = False
    client = app.test_client()

    # Log in, then post without supplying any CSRF header.
    client.get(auth_url)
    result = client.post(post_url)
    assert result.status_code == 200
    assert result.get_json() == {"foo": "bar"}
209
210
@pytest.mark.parametrize(
    "options",
    [
        ("/refresh_token", "csrf_refresh_token", "/post_refresh_protected"),
        ("/access_token", "csrf_access_token", "/post_protected"),
    ],
)
def test_csrf_with_custom_header_names(app, options):
    """The CSRF value is accepted from the configured custom header name."""
    auth_url, csrf_cookie_name, post_url = options
    app.config.update(
        JWT_ACCESS_CSRF_HEADER_NAME="FOO",
        JWT_REFRESH_CSRF_HEADER_NAME="FOO",
    )
    client = app.test_client()

    # Log in and read the CSRF double submit value out of its cookie.
    login = client.get(auth_url)
    csrf_cookie = _get_cookie_from_response(login, csrf_cookie_name)
    csrf_token = csrf_cookie[csrf_cookie_name]

    # Posting with the value in the custom header must succeed.
    result = client.post(post_url, headers={"FOO": csrf_token})
    assert result.status_code == 200
    assert result.get_json() == {"foo": "bar"}
233
234
@pytest.mark.parametrize(
    "options",
    [
        ("/refresh_token", "csrf_refresh_token", "/post_refresh_protected"),
        ("/access_token", "csrf_access_token", "/post_protected"),
    ],
)
def test_csrf_with_default_form_field(app, options):
    """With JWT_CSRF_CHECK_FORM on, the ``csrf_token`` form field is accepted."""
    auth_url, csrf_cookie_name, post_url = options
    app.config["JWT_CSRF_CHECK_FORM"] = True
    client = app.test_client()

    # Log in and read the CSRF double submit value out of its cookie.
    login = client.get(auth_url)
    csrf_cookie = _get_cookie_from_response(login, csrf_cookie_name)
    csrf_token = csrf_cookie[csrf_cookie_name]

    # Posting the value as form data (no header) must succeed.
    result = client.post(post_url, data={"csrf_token": csrf_token})
    assert result.status_code == 200
    assert result.get_json() == {"foo": "bar"}
256
257
@pytest.mark.parametrize(
    "options",
    [
        ("/refresh_token", "csrf_refresh_token", "/post_refresh_protected"),
        ("/access_token", "csrf_access_token", "/post_protected"),
    ],
)
def test_csrf_with_custom_form_field(app, options):
    """The CSRF value is accepted from a custom-named form field."""
    auth_url, csrf_cookie_name, post_url = options
    app.config["JWT_CSRF_CHECK_FORM"] = True
    app.config.update(
        JWT_ACCESS_CSRF_FIELD_NAME="FOO",
        JWT_REFRESH_CSRF_FIELD_NAME="FOO",
    )
    client = app.test_client()

    # Log in and read the CSRF double submit value out of its cookie.
    login = client.get(auth_url)
    csrf_cookie = _get_cookie_from_response(login, csrf_cookie_name)
    csrf_token = csrf_cookie[csrf_cookie_name]

    # Posting the value under the custom form field name must succeed.
    result = client.post(post_url, data={"FOO": csrf_token})
    assert result.status_code == 200
    assert result.get_json() == {"foo": "bar"}
281
282
@pytest.mark.parametrize(
    "options",
    [
        (
            "/refresh_token",
            "csrf_refresh_token",
            "/refresh_protected",
            "/post_refresh_protected",
        ),
        (
            "/access_token",
            "csrf_access_token",
            "/protected",
            "/post_protected",
        ),
    ],
)
def test_custom_csrf_methods(app, options):
    """JWT_CSRF_METHODS selects which HTTP methods need the CSRF value."""
    auth_url, csrf_cookie_name, get_url, post_url = options
    app.config["JWT_CSRF_METHODS"] = ["GET"]
    client = app.test_client()

    # Log in and read the CSRF double submit value out of its cookie.
    login = client.get(auth_url)
    csrf_cookie = _get_cookie_from_response(login, csrf_cookie_name)
    csrf_token = csrf_cookie[csrf_cookie_name]

    # Ensure POST requests no longer need a CSRF value.
    result = client.post(post_url)
    assert result.status_code == 200
    assert result.get_json() == {"foo": "bar"}

    # Ensure GET requests are now refused without one.
    result = client.get(get_url)
    assert result.status_code == 401
    assert result.get_json() == {"msg": "Missing CSRF token"}

    # Ensure GET requests succeed when the CSRF header is supplied.
    result = client.get(get_url, headers={"X-CSRF-TOKEN": csrf_token})
    assert result.status_code == 200
    assert result.get_json() == {"foo": "bar"}
319
320
def test_default_cookie_options(app):
    """Check path/httponly/samesite of the cookies under default settings."""
    client = app.test_client()
    cases = (
        ("/access_token", "access_token_cookie", "csrf_access_token"),
        ("/refresh_token", "refresh_token_cookie", "csrf_refresh_token"),
    )

    for auth_url, jwt_cookie_name, csrf_cookie_name in cases:
        login = client.get(auth_url)
        # One cookie carries the JWT, the other the CSRF value.
        assert len(login.headers.getlist("Set-Cookie")) == 2

        jwt_cookie = _get_cookie_from_response(login, jwt_cookie_name)
        assert jwt_cookie is not None
        assert jwt_cookie["path"] == "/"
        assert jwt_cookie["httponly"] is True
        assert "samesite" not in jwt_cookie

        csrf_cookie = _get_cookie_from_response(login, csrf_cookie_name)
        assert csrf_cookie is not None
        assert csrf_cookie["path"] == "/"
        assert "httponly" not in csrf_cookie
        assert "samesite" not in csrf_cookie
357
358
def test_custom_cookie_options(app):
    """Secure, domain, expiry and samesite settings show up on every cookie."""
    client = app.test_client()

    app.config.update(
        JWT_COOKIE_SECURE=True,
        JWT_COOKIE_DOMAIN="test.com",
        JWT_SESSION_COOKIE=False,
        JWT_COOKIE_SAMESITE="Strict",
    )

    cases = (
        ("/access_token", "access_token_cookie", "csrf_access_token"),
        ("/refresh_token", "refresh_token_cookie", "csrf_refresh_token"),
    )

    for auth_url, jwt_cookie_name, csrf_cookie_name in cases:
        login = client.get(auth_url)
        # One cookie carries the JWT, the other the CSRF value.
        assert len(login.headers.getlist("Set-Cookie")) == 2

        jwt_cookie = _get_cookie_from_response(login, jwt_cookie_name)
        assert jwt_cookie is not None
        assert jwt_cookie["domain"] == "test.com"
        assert jwt_cookie["path"] == "/"
        assert jwt_cookie["expires"] != ""
        assert jwt_cookie["httponly"] is True
        assert jwt_cookie["secure"] is True
        assert jwt_cookie["samesite"] == "Strict"

        csrf_cookie = _get_cookie_from_response(login, csrf_cookie_name)
        assert csrf_cookie is not None
        assert csrf_cookie["path"] == "/"
        assert csrf_cookie["secure"] is True
        assert csrf_cookie["domain"] == "test.com"
        assert csrf_cookie["expires"] != ""
        assert csrf_cookie["samesite"] == "Strict"
410
411
def test_custom_cookie_names_and_paths(app):
    """Configured cookie names and paths are used for JWT and CSRF cookies."""
    client = app.test_client()

    app.config.update(
        JWT_ACCESS_CSRF_COOKIE_NAME="access_foo_csrf",
        JWT_REFRESH_CSRF_COOKIE_NAME="refresh_foo_csrf",
        JWT_ACCESS_CSRF_COOKIE_PATH="/protected",
        JWT_REFRESH_CSRF_COOKIE_PATH="/refresh_protected",
        JWT_ACCESS_COOKIE_NAME="access_foo",
        JWT_REFRESH_COOKIE_NAME="refresh_foo",
        JWT_ACCESS_COOKIE_PATH="/protected",
        JWT_REFRESH_COOKIE_PATH="/refresh_protected",
    )

    cases = (
        ("/access_token", "access_foo", "access_foo_csrf", "/protected"),
        ("/refresh_token", "refresh_foo", "refresh_foo_csrf", "/refresh_protected"),
    )

    for auth_url, jwt_cookie_name, csrf_cookie_name, expected_path in cases:
        login = client.get(auth_url)
        # One cookie carries the JWT, the other the CSRF value.
        assert len(login.headers.getlist("Set-Cookie")) == 2

        # Both cookies must exist under their custom names ...
        jwt_cookie = _get_cookie_from_response(login, jwt_cookie_name)
        csrf_cookie = _get_cookie_from_response(login, csrf_cookie_name)
        assert jwt_cookie is not None
        assert csrf_cookie is not None
        # ... and be scoped to the custom path.
        assert jwt_cookie["path"] == expected_path
        assert csrf_cookie["path"] == expected_path
447
448
def test_csrf_token_not_in_cookie(app):
    """With JWT_CSRF_IN_COOKIES off, only the JWT cookie is set."""
    client = app.test_client()

    app.config["JWT_CSRF_IN_COOKIES"] = False

    cases = (
        ("/access_token", "access_token_cookie"),
        ("/refresh_token", "refresh_token_cookie"),
    )
    for auth_url, jwt_cookie_name in cases:
        login = client.get(auth_url)
        # Exactly one Set-Cookie header, and it is the JWT cookie.
        assert len(login.headers.getlist("Set-Cookie")) == 1
        assert _get_cookie_from_response(login, jwt_cookie_name) is not None
467
468
def test_cookies_without_csrf(app):
    """With JWT_COOKIE_CSRF_PROTECT off, only the JWT cookie is set."""
    client = app.test_client()

    app.config["JWT_COOKIE_CSRF_PROTECT"] = False

    cases = (
        ("/access_token", "access_token_cookie"),
        ("/refresh_token", "refresh_token_cookie"),
    )
    for auth_url, jwt_cookie_name in cases:
        login = client.get(auth_url)
        # Exactly one Set-Cookie header, and it is the JWT cookie.
        assert len(login.headers.getlist("Set-Cookie")) == 1
        assert _get_cookie_from_response(login, jwt_cookie_name) is not None
487
488
def test_jwt_optional_with_csrf_enabled(app):
    """Optional endpoints skip CSRF for anonymous users but not for tokens."""
    client = app.test_client()

    # Without any token the endpoint is reachable and no CSRF error is raised.
    anonymous = client.post("/optional_post_protected")
    assert anonymous.status_code == 200
    assert anonymous.get_json() == {"foo": "bar"}

    # After logging in, a POST that lacks the CSRF value is refused.
    client.get("/access_token")
    authenticated = client.post("/optional_post_protected")
    assert authenticated.status_code == 401
    assert authenticated.get_json() == {"msg": "Missing CSRF token"}
503
504
@pytest.mark.parametrize(
    "options",
    [
        (
            "/access_token",
            "/delete_access_tokens",
            "access_token_cookie",
            "csrf_access_token",
        ),
        (
            "/refresh_token",
            "/delete_refresh_tokens",
            "refresh_token_cookie",
            "csrf_refresh_token",
        ),
    ],
)
def test_override_domain_option(app, options):
    """An explicit ``domain`` argument wins over JWT_COOKIE_DOMAIN."""
    auth_url, delete_url, auth_cookie_name, csrf_cookie_name = options
    domain = "yolo.com"

    client = app.test_client()
    app.config["JWT_COOKIE_DOMAIN"] = "test.com"

    # First set the cookies, then unset them; both responses must carry the
    # overriding domain on the JWT cookie and on the CSRF cookie.
    for url in (auth_url, delete_url):
        result = client.get(f"{url}?domain={domain}")
        # One cookie carries the JWT, the other the CSRF value.
        assert len(result.headers.getlist("Set-Cookie")) == 2

        jwt_cookie = _get_cookie_from_response(result, auth_cookie_name)
        assert jwt_cookie is not None
        assert jwt_cookie["domain"] == domain

        csrf_cookie = _get_cookie_from_response(result, csrf_cookie_name)
        assert csrf_cookie is not None
        assert csrf_cookie["domain"] == domain
554