1import pytest
2from flask import Flask
3from flask import jsonify
4
5from flask_jwt_extended import create_access_token
6from flask_jwt_extended import get_jwt_request_location
7from flask_jwt_extended import jwt_required
8from flask_jwt_extended import JWTManager
9from flask_jwt_extended import set_access_cookies
10
11
def _create_app(token_location, route_locations=None):
    """Build a Flask app exposing a cookie-login route and a protected route.

    :param token_location: value stored in ``JWT_TOKEN_LOCATION``.
    :param route_locations: passed as ``locations`` to ``jwt_required`` on
        ``/protected``; ``None`` (the decorator's default) leaves the route
        relying on ``JWT_TOKEN_LOCATION``.
    :return: the configured Flask app.
    """
    app = Flask(__name__)
    app.config["JWT_SECRET_KEY"] = "foobarbaz"
    app.config["JWT_TOKEN_LOCATION"] = token_location
    JWTManager(app)

    @app.route("/cookie_login", methods=["GET"])
    def cookie_login():
        # Issues an access token and stores it in a cookie on the response.
        resp = jsonify(login=True)
        access_token = create_access_token("username")
        set_access_cookies(resp, access_token)
        return resp

    @app.route("/protected", methods=["GET", "POST"])
    @jwt_required(locations=route_locations)
    def access_protected():
        # Echoes back which location the accepted token was read from.
        return jsonify(foo="bar", location=get_jwt_request_location())

    return app


@pytest.fixture(scope="function")
def app():
    """App whose accepted token locations come from ``JWT_TOKEN_LOCATION``."""
    return _create_app(["headers", "cookies", "query_string", "json"])


@pytest.fixture(scope="function")
def app_with_locations():
    """App configured for headers only, whose ``/protected`` route overrides
    that by passing all four locations to ``jwt_required``."""
    return _create_app(
        ["headers"],
        route_locations=["headers", "cookies", "query_string", "json"],
    )
55
56
def test_header_access(app, app_with_locations):
    """A valid token in the Authorization header is accepted by both apps.

    ``app`` takes its locations from ``JWT_TOKEN_LOCATION``; ``app_with_locations``
    gets them from the ``locations`` argument of ``jwt_required``.
    """
    # Use a distinct loop name so the ``app`` fixture argument is not rebound,
    # and label each app so a failure says which configuration broke.
    for label, flask_app in (("app", app), ("app_with_locations", app_with_locations)):
        test_client = flask_app.test_client()
        # create_access_token needs an application context.
        with flask_app.test_request_context():
            access_token = create_access_token("username")

        access_headers = {"Authorization": "Bearer {}".format(access_token)}
        response = test_client.get("/protected", headers=access_headers)
        assert response.status_code == 200, label
        assert response.get_json() == {"foo": "bar", "location": "headers"}, label
67
68
def test_cookie_access(app, app_with_locations):
    """A token stored in a cookie by ``/cookie_login`` is accepted by both apps."""
    # Distinct loop name avoids rebinding the ``app`` fixture argument; the
    # label identifies which app failed.
    for label, flask_app in (("app", app), ("app_with_locations", app_with_locations)):
        test_client = flask_app.test_client()
        # The login response sets the access-token cookie on this client.
        test_client.get("/cookie_login")
        response = test_client.get("/protected")
        assert response.status_code == 200, label
        assert response.get_json() == {"foo": "bar", "location": "cookies"}, label
76
77
def test_query_string_access(app, app_with_locations):
    """A token passed as the ``jwt`` query parameter is accepted by both apps."""
    # Distinct loop name avoids rebinding the ``app`` fixture argument; the
    # label identifies which app failed.
    for label, flask_app in (("app", app), ("app_with_locations", app_with_locations)):
        test_client = flask_app.test_client()
        # create_access_token needs an application context.
        with flask_app.test_request_context():
            access_token = create_access_token("username")

        url = "/protected?jwt={}".format(access_token)
        response = test_client.get(url)
        assert response.status_code == 200, label
        assert response.get_json() == {"foo": "bar", "location": "query_string"}, label
88
89
def test_json_access(app, app_with_locations):
    """A token sent as ``access_token`` in a JSON POST body is accepted by both apps."""
    # Distinct loop name avoids rebinding the ``app`` fixture argument; the
    # label identifies which app failed.
    for label, flask_app in (("app", app), ("app_with_locations", app_with_locations)):
        test_client = flask_app.test_client()
        # create_access_token needs an application context.
        with flask_app.test_request_context():
            access_token = create_access_token("username")

        data = {"access_token": access_token}
        response = test_client.post("/protected", json=data)
        assert response.status_code == 200, label
        assert response.get_json() == {"foo": "bar", "location": "json"}, label
99
100
@pytest.mark.parametrize(
    "options",
    [
        (
            ["cookies", "headers"],
            'Missing JWT in cookies or headers (Missing cookie "access_token_cookie"; '
            "Missing Authorization Header)",
        ),
        (
            ["json", "query_string"],
            # The "paramater" spelling must match the emitted message exactly.
            "Missing JWT in json or query_string "
            "(Invalid content-type. Must be application/json.; "
            "Missing 'jwt' query paramater)",
        ),
    ],
)
def test_no_jwt_in_request(app, options):
    """A request with no token is rejected with 401, and the message names
    each configured location and why it yielded nothing."""
    locations, expected_msg = options
    app.config["JWT_TOKEN_LOCATION"] = locations

    resp = app.test_client().get("/protected")

    assert resp.status_code == 401
    assert resp.get_json() == {"msg": expected_msg}
128
129
@pytest.mark.parametrize(
    "options",
    [
        (["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}),
        (["headers", "cookies"], 200, None, {"foo": "bar", "location": "cookies"}),
    ],
)
def test_order_of_jwt_locations_in_request(app, options):
    """Only a cookie token is sent, so it is found whichever order the
    locations are configured in.

    Each option is ``(locations, status, error_msg, body)``; when ``body`` is
    falsy the response is expected to be ``{"msg": error_msg}`` instead.
    """
    locations, expected_status, expected_msg, expected_body = options
    app.config["JWT_TOKEN_LOCATION"] = locations
    client = app.test_client()
    client.get("/cookie_login")  # stores the access-token cookie on the client

    resp = client.get("/protected")

    assert resp.status_code == expected_status
    expected_json = expected_body if expected_body else {"msg": expected_msg}
    assert resp.get_json() == expected_json
150
151
@pytest.mark.parametrize(
    "options",
    [
        (["cookies", "headers"], 200, None, {"foo": "bar", "location": "cookies"}),
        (["headers", "cookies"], 422, "Invalid header padding", None),
    ],
)
def test_order_of_jwt_locations_with_one_invalid_token_in_request(app, options):
    """The first configured location holding a token decides the outcome.

    The request carries both a valid cookie token and a corrupted header token.
    With cookies listed first, the valid cookie is used (200). With headers
    listed first, the corrupted header token is the one decoded, and the
    request is rejected (422) even though a valid cookie is also present.

    Each option is ``(locations, status, error_msg, body)``; when ``body`` is
    falsy the response is expected to be ``{"msg": error_msg}`` instead.
    """
    token_locations, status_code, expected_err, expected_dict = options
    app.config["JWT_TOKEN_LOCATION"] = token_locations
    test_client = app.test_client()

    # create_access_token needs an application context.
    with app.test_request_context():
        access_token = create_access_token("username")
    # Overwrite the start of the token (its first, header segment) so it no
    # longer decodes; the 422 case expects this to surface as
    # "Invalid header padding".
    access_token = "000000{}".format(access_token[5:])
    access_headers = {"Authorization": "Bearer {}".format(access_token)}
    # The login response sets a valid access-token cookie on this client.
    test_client.get("/cookie_login")
    response = test_client.get("/protected", headers=access_headers)

    assert response.status_code == status_code
    if expected_dict:
        assert response.get_json() == expected_dict
    else:
        assert response.get_json() == {"msg": expected_err}
179