1# Copyright 2014 Google Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import base64
16import datetime
17import json
18import os
19
20import mock
21import pytest
22
23from google.auth import _helpers
24from google.auth import crypt
25from google.auth import exceptions
26from google.auth import jwt
27
28
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")


def _read_data_bytes(filename):
    """Return the raw bytes of ``filename`` located inside ``DATA_DIR``."""
    with open(os.path.join(DATA_DIR, filename), "rb") as handle:
        return handle.read()


# RSA material: the signing key, its matching certificate, and an unrelated
# certificate that tokens signed with PRIVATE_KEY_BYTES must not verify against.
PRIVATE_KEY_BYTES = _read_data_bytes("privatekey.pem")
PUBLIC_CERT_BYTES = _read_data_bytes("public_cert.pem")
OTHER_CERT_BYTES = _read_data_bytes("other_cert.pem")

# ES256 (elliptic curve) signing key and its matching certificate.
EC_PRIVATE_KEY_BYTES = _read_data_bytes("es256_privatekey.pem")
EC_PUBLIC_CERT_BYTES = _read_data_bytes("es256_public_cert.pem")

SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")

with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
    SERVICE_ACCOUNT_INFO = json.load(fh)
50
51
@pytest.fixture
def signer():
    """RSA signer loaded from the test private key, tagged with key id "1"."""
    rsa_signer = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1")
    return rsa_signer
55
56
def test_encode_basic(signer):
    """Encoding with the RSA signer round-trips the payload with RS256 headers."""
    claims = {"test": "value"}
    token = jwt.encode(signer, claims)

    decoded_header, decoded_payload, _, _ = jwt._unverified_decode(token)

    assert decoded_payload == claims
    expected_header = {"typ": "JWT", "alg": "RS256", "kid": signer.key_id}
    assert decoded_header == expected_header
63
64
def test_encode_extra_headers(signer):
    """Caller-supplied header fields appear alongside the default ones."""
    token = jwt.encode(signer, {}, header={"extra": "value"})
    decoded = jwt.decode_header(token)

    expected = dict(typ="JWT", alg="RS256", kid=signer.key_id, extra="value")
    assert decoded == expected
74
75
def test_encode_custom_alg_in_headers(signer):
    """An explicit ``alg`` header value replaces the signer's default RS256."""
    token = jwt.encode(signer, {}, header={"alg": "foo"})
    decoded = jwt.decode_header(token)
    assert decoded == dict(typ="JWT", alg="foo", kid=signer.key_id)
80
81
@pytest.fixture
def es256_signer():
    """ES256 signer loaded from the test EC private key, tagged with key id "1"."""
    ec_signer = crypt.ES256Signer.from_string(EC_PRIVATE_KEY_BYTES, "1")
    return ec_signer
85
86
def test_encode_basic_es256(es256_signer):
    """Encoding with the ES256 signer round-trips the payload with ES256 headers."""
    claims = {"test": "value"}
    token = jwt.encode(es256_signer, claims)

    decoded_header, decoded_payload, _, _ = jwt._unverified_decode(token)

    assert decoded_payload == claims
    expected_header = {"typ": "JWT", "alg": "ES256", "kid": es256_signer.key_id}
    assert decoded_header == expected_header
93
94
@pytest.fixture
def token_factory(signer, es256_signer):
    """Return a factory that builds signed JWTs for the decode tests.

    The factory accepts:
        claims: optional mapping merged over the default payload (audience,
            iat/exp spanning the next 300 seconds, and a few custom fields).
        key_id: key id for the token header. ``None`` falls back to the
            signer's own key id; ``False`` strips the key id from the signer
            that will sign the token, so the header carries no ``kid``.
        use_es256_signer: sign with the ES256 signer instead of the RSA one.
    """

    def factory(claims=None, key_id=None, use_es256_signer=False):
        now = _helpers.datetime_to_secs(_helpers.utcnow())
        payload = {
            "aud": "audience@example.com",
            "iat": now,
            "exp": now + 300,
            "user": "billy bob",
            "metadata": {"meta": "data"},
        }
        payload.update(claims or {})

        # Choose the signer first so key-id stripping targets the one that
        # actually signs; previously the RSA signer was always modified, which
        # left ES256 tokens with a kid even when key_id=False was requested.
        active_signer = es256_signer if use_es256_signer else signer

        # False is specified to remove the signer's key id for testing
        # headers without key ids.
        if key_id is False:
            active_signer._key_id = None
            key_id = None

        return jwt.encode(active_signer, payload, key_id=key_id)

    return factory
120
121
def test_decode_valid(token_factory):
    """A freshly minted RS256 token verifies against its public certificate."""
    token = token_factory()
    claims = jwt.decode(token, certs=PUBLIC_CERT_BYTES)

    assert claims["aud"] == "audience@example.com"
    assert claims["user"] == "billy bob"
    assert claims["metadata"]["meta"] == "data"
127
128
def test_decode_valid_es256(token_factory):
    """A freshly minted ES256 token verifies against the EC public certificate."""
    token = token_factory(use_es256_signer=True)
    claims = jwt.decode(token, certs=EC_PUBLIC_CERT_BYTES)

    assert claims["aud"] == "audience@example.com"
    assert claims["user"] == "billy bob"
    assert claims["metadata"]["meta"] == "data"
136
137
def test_decode_valid_with_audience(token_factory):
    """Decoding succeeds when the expected audience matches the ``aud`` claim."""
    expected_audience = "audience@example.com"
    token = token_factory()
    claims = jwt.decode(token, certs=PUBLIC_CERT_BYTES, audience=expected_audience)

    assert claims["aud"] == "audience@example.com"
    assert claims["user"] == "billy bob"
    assert claims["metadata"]["meta"] == "data"
145
146
def test_decode_valid_with_audience_list(token_factory):
    """Decoding succeeds when the ``aud`` claim is one of several accepted audiences."""
    accepted = ["audience@example.com", "another_audience@example.com"]
    token = token_factory()
    claims = jwt.decode(token, certs=PUBLIC_CERT_BYTES, audience=accepted)

    assert claims["aud"] == "audience@example.com"
    assert claims["user"] == "billy bob"
    assert claims["metadata"]["meta"] == "data"
156
157
def test_decode_valid_unverified(token_factory):
    """With ``verify=False`` a non-matching certificate does not block decoding."""
    token = token_factory()
    claims = jwt.decode(token, certs=OTHER_CERT_BYTES, verify=False)

    assert claims["aud"] == "audience@example.com"
    assert claims["user"] == "billy bob"
    assert claims["metadata"]["meta"] == "data"
163
164
def test_decode_bad_token_wrong_number_of_segments():
    """A token with only two dot-separated segments is rejected."""
    malformed_token = "1.2"
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(malformed_token, PUBLIC_CERT_BYTES)
    assert excinfo.match(r"Wrong number of segments")
169
170
def test_decode_bad_token_not_base64():
    """Segments that are not valid base64 surface a base64 decoding error."""
    malformed_token = "1.2.3"
    with pytest.raises((ValueError, TypeError)) as excinfo:
        jwt.decode(malformed_token, PUBLIC_CERT_BYTES)
    assert excinfo.match(r"Incorrect padding|more than a multiple of 4")
175
176
def test_decode_bad_token_not_json():
    """Well-formed base64 segments that do not contain JSON are rejected."""
    segment = base64.urlsafe_b64encode(b"123!")
    token = b".".join([segment, segment, segment])

    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES)
    assert excinfo.match(r"Can\'t parse segment")
182
183
def test_decode_bad_token_no_iat_or_exp(signer):
    """A signed token lacking ``iat``/``exp`` claims fails validation."""
    claims_without_times = {"test": "value"}
    token = jwt.encode(signer, claims_without_times)

    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES)
    assert excinfo.match(r"Token does not contain required claim")
189
190
def test_decode_bad_token_too_early(token_factory):
    """A token issued an hour in the future is rejected despite a 59s skew."""
    one_hour_ahead = _helpers.utcnow() + datetime.timedelta(hours=1)
    future_iat = _helpers.datetime_to_secs(one_hour_ahead)
    token = token_factory(claims={"iat": future_iat})

    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES, clock_skew_in_seconds=59)
    assert excinfo.match(r"Token used too early")
202
203
def test_decode_bad_token_expired(token_factory):
    """A token that expired an hour ago is rejected despite a 59s skew."""
    one_hour_ago = _helpers.utcnow() - datetime.timedelta(hours=1)
    past_exp = _helpers.datetime_to_secs(one_hour_ago)
    token = token_factory(claims={"exp": past_exp})

    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES, clock_skew_in_seconds=59)
    assert excinfo.match(r"Token expired")
215
216
def test_decode_success_with_no_clock_skew(token_factory):
    """A token whose iat/exp tightly bracket "now" decodes without a skew argument."""
    # Sample the clock once so iat and exp are symmetric around the same instant.
    now = _helpers.utcnow()
    one_second = datetime.timedelta(seconds=1)
    token = token_factory(
        claims={
            "exp": _helpers.datetime_to_secs(now + one_second),
            "iat": _helpers.datetime_to_secs(now - one_second),
        }
    )

    payload = jwt.decode(token, PUBLIC_CERT_BYTES)

    # Check the decoded content too, not merely that no exception was raised.
    assert payload["user"] == "billy bob"
    assert payload["aud"] == "audience@example.com"
230
231
def test_decode_success_with_custom_clock_skew(token_factory):
    """Tokens just outside their validity window are accepted within the skew.

    Each token below is rejected when no skew is passed: one is issued slightly
    in the future, the other expired slightly in the past. The later successful
    decodes therefore show that ``clock_skew_in_seconds`` is actually applied.
    The previous version used a token valid at zero skew, so it passed
    regardless of the argument.
    """
    now = _helpers.utcnow()
    five_seconds = datetime.timedelta(seconds=5)
    issued_in_future = token_factory(
        claims={"iat": _helpers.datetime_to_secs(now + five_seconds)}
    )
    recently_expired = token_factory(
        claims={
            "iat": _helpers.datetime_to_secs(now - 2 * five_seconds),
            "exp": _helpers.datetime_to_secs(now - five_seconds),
        }
    )

    for token in (issued_in_future, recently_expired):
        # Without any allowance the token lies outside its validity window ...
        with pytest.raises(ValueError):
            jwt.decode(token, PUBLIC_CERT_BYTES)
        # ... but a 10 second skew brings it back inside.
        payload = jwt.decode(token, PUBLIC_CERT_BYTES, clock_skew_in_seconds=10)
        assert payload["user"] == "billy bob"
245
246
def test_decode_bad_token_wrong_audience(token_factory):
    """Decoding fails when the expected audience differs from the ``aud`` claim."""
    token = token_factory()
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES, audience="audience2@example.com")
    assert excinfo.match(r"Token has wrong audience")
253
254
def test_decode_bad_token_wrong_audience_list(token_factory):
    """Decoding fails when no accepted audience matches the ``aud`` claim."""
    token = token_factory()
    accepted = ["audience2@example.com", "audience3@example.com"]
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES, audience=accepted)
    assert excinfo.match(r"Token has wrong audience")
261
262
def test_decode_wrong_cert(token_factory):
    """A token does not verify against an unrelated certificate."""
    token = token_factory()
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, OTHER_CERT_BYTES)
    assert excinfo.match(r"Could not verify token signature")
267
268
def test_decode_multicert_bad_cert(token_factory):
    """The cert under the token's kid ("1") is wrong, so verification fails.

    This holds even though the matching certificate is present under "2".
    """
    token = token_factory()
    certs_by_kid = {"1": OTHER_CERT_BYTES, "2": PUBLIC_CERT_BYTES}
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, certs_by_kid)
    assert excinfo.match(r"Could not verify token signature")
274
275
def test_decode_no_cert(token_factory):
    """Decoding fails when no certificate is registered for the token's kid."""
    token = token_factory()
    certs_by_kid = {"2": PUBLIC_CERT_BYTES}
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, certs_by_kid)
    assert excinfo.match(r"Certificate for key id 1 not found")
281
282
def test_decode_no_key_id(token_factory):
    """A token without a ``kid`` verifies against a cert stored under another id."""
    token_without_kid = token_factory(key_id=False)
    claims = jwt.decode(token_without_kid, {"2": PUBLIC_CERT_BYTES})
    assert claims["user"] == "billy bob"
288
289
def test_decode_unknown_alg():
    """A header naming an unsupported algorithm produces an error naming it."""
    headers = json.dumps({u"kid": u"1", u"alg": u"fakealg"})
    parts = [headers, u"{}", u"sig"]
    token = b".".join(base64.b64encode(part.encode("utf-8")) for part in parts)

    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token)
    assert excinfo.match(r"fakealg")
299
300
def test_decode_missing_crytography_alg(monkeypatch):
    """Removing the registered ES256 verifier makes ES256 decoding fail.

    The verifier is dropped from ``_ALGORITHM_TO_VERIFIER_CLASS`` for the
    duration of the test. The resulting error is expected to mention
    ``cryptography``.
    """
    monkeypatch.delitem(jwt._ALGORITHM_TO_VERIFIER_CLASS, "ES256")
    headers = json.dumps({u"kid": u"1", u"alg": u"ES256"})
    parts = [headers, u"{}", u"sig"]
    token = b".".join(base64.b64encode(part.encode("utf-8")) for part in parts)

    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token)
    assert excinfo.match(r"cryptography")
311
312
def test_roundtrip_explicit_key_id(token_factory):
    """An explicit kid selects the certificate registered under that id."""
    token = token_factory(key_id="3")
    certs_by_kid = {"2": OTHER_CERT_BYTES, "3": PUBLIC_CERT_BYTES}
    claims = jwt.decode(token, certs_by_kid)
    assert claims["user"] == "billy bob"
318
319
class TestCredentials(object):
    """Tests for ``jwt.Credentials`` built around the RSA ``signer`` fixture."""

    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    SUBJECT = "subject"
    AUDIENCE = "audience"
    ADDITIONAL_CLAIMS = {"meta": "data"}
    # Populated before every test by the autouse fixture below.
    credentials = None

    @pytest.fixture(autouse=True)
    def credentials_fixture(self, signer):
        # The service account email serves as both issuer and subject.
        email = self.SERVICE_ACCOUNT_EMAIL
        self.credentials = jwt.Credentials(signer, email, email, self.AUDIENCE)

    @staticmethod
    def _assert_service_account_fields(creds, info, subject, audience):
        """Check signer key id, issuer, subject and audience of ``creds``."""
        assert creds._signer.key_id == info["private_key_id"]
        assert creds._issuer == info["client_email"]
        assert creds._subject == subject
        assert creds._audience == audience

    def test_from_service_account_info(self):
        with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
            info = json.load(fh)

        creds = jwt.Credentials.from_service_account_info(
            info, audience=self.AUDIENCE
        )

        # Without an explicit subject the client email is used.
        self._assert_service_account_fields(
            creds, info, info["client_email"], self.AUDIENCE
        )

    def test_from_service_account_info_args(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        creds = jwt.Credentials.from_service_account_info(
            info,
            subject=self.SUBJECT,
            audience=self.AUDIENCE,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        self._assert_service_account_fields(creds, info, self.SUBJECT, self.AUDIENCE)
        assert creds._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_service_account_file(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        creds = jwt.Credentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE
        )

        self._assert_service_account_fields(
            creds, info, info["client_email"], self.AUDIENCE
        )

    def test_from_service_account_file_args(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        creds = jwt.Credentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE,
            subject=self.SUBJECT,
            audience=self.AUDIENCE,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        self._assert_service_account_fields(creds, info, self.SUBJECT, self.AUDIENCE)
        assert creds._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_signing_credentials(self):
        audience = mock.sentinel.new_audience
        from_signing = self.credentials.from_signing_credentials(
            self.credentials, audience=audience
        )
        from_info = jwt.Credentials.from_service_account_info(
            SERVICE_ACCOUNT_INFO, audience=audience
        )

        assert isinstance(from_signing, jwt.Credentials)
        assert from_signing._signer.key_id == from_info._signer.key_id
        for attr in ("_issuer", "_subject", "_audience"):
            assert getattr(from_signing, attr) == getattr(from_info, attr)

    def test_default_state(self):
        creds = self.credentials
        assert not creds.valid
        # Expiration hasn't been set yet
        assert not creds.expired

    def test_with_claims(self):
        new_audience = "new_audience"
        original = self.credentials
        updated = original.with_claims(audience=new_audience)

        for attr in ("_signer", "_issuer", "_subject"):
            assert getattr(updated, attr) == getattr(original, attr)
        assert updated._audience == new_audience
        for attr in ("_additional_claims", "_quota_project_id"):
            assert getattr(updated, attr) == getattr(original, attr)

    def test__make_jwt_without_audience(self):
        cred = jwt.Credentials.from_service_account_info(
            SERVICE_ACCOUNT_INFO.copy(),
            subject=self.SUBJECT,
            audience=None,
            additional_claims={"scope": "foo bar"},
        )
        token, _ = cred._make_jwt()

        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["scope"] == "foo bar"
        # No audience was configured, so no "aud" claim may be emitted.
        assert "aud" not in payload

    def test_with_quota_project(self):
        quota_project_id = "project-foo"
        original = self.credentials
        updated = original.with_quota_project(quota_project_id)

        unchanged = ("_signer", "_issuer", "_subject", "_audience", "_additional_claims")
        for attr in unchanged:
            assert getattr(updated, attr) == getattr(original, attr)
        assert updated._quota_project_id == quota_project_id

    def test_sign_bytes(self):
        message = b"123"
        signature = self.credentials.sign_bytes(message)
        assert crypt.verify_signature(message, signature, PUBLIC_CERT_BYTES)

    def test_signer(self):
        assert isinstance(self.credentials.signer, crypt.RSASigner)

    def test_signer_email(self):
        expected_email = SERVICE_ACCOUNT_INFO["client_email"]
        assert self.credentials.signer_email == expected_email

    def _verify_token(self, token):
        """Decode ``token`` with the public cert, check the issuer, return claims."""
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        return payload

    def test_refresh(self):
        creds = self.credentials
        creds.refresh(None)
        assert creds.valid
        assert not creds.expired

    def test_expired(self):
        creds = self.credentials
        assert not creds.expired

        creds.refresh(None)
        assert not creds.expired

        # Move the clock a day past the expiry; the credentials must notice.
        with mock.patch("google.auth._helpers.utcnow") as fake_utcnow:
            fake_utcnow.return_value = creds.expiry + datetime.timedelta(days=1)
            assert creds.expired

    def test_before_request(self):
        headers = {}
        creds = self.credentials

        creds.refresh(None)
        creds.before_request(None, "GET", "http://example.com?a=1#3", headers)

        _, token = headers["authorization"].split(" ")

        # Since the audience is set, it should use the existing token.
        assert token.encode("utf-8") == creds.token

        payload = self._verify_token(token)
        assert payload["aud"] == self.AUDIENCE

    def test_before_request_refreshes(self):
        creds = self.credentials
        assert not creds.valid
        creds.before_request(None, "GET", "http://example.com?a=1#3", {})
        assert creds.valid
499
500
class TestOnDemandCredentials(object):
    """Tests for ``jwt.OnDemandCredentials`` built around the RSA ``signer``."""

    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    SUBJECT = "subject"
    ADDITIONAL_CLAIMS = {"meta": "data"}
    # Populated before every test by the autouse fixture below.
    credentials = None

    @pytest.fixture(autouse=True)
    def credentials_fixture(self, signer):
        # The service account email serves as both issuer and subject.
        email = self.SERVICE_ACCOUNT_EMAIL
        self.credentials = jwt.OnDemandCredentials(
            signer, email, email, max_cache_size=2
        )

    @staticmethod
    def _assert_service_account_fields(creds, info, subject):
        """Check signer key id, issuer and subject of ``creds``."""
        assert creds._signer.key_id == info["private_key_id"]
        assert creds._issuer == info["client_email"]
        assert creds._subject == subject

    def test_from_service_account_info(self):
        with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
            info = json.load(fh)

        creds = jwt.OnDemandCredentials.from_service_account_info(info)

        # Without an explicit subject the client email is used.
        self._assert_service_account_fields(creds, info, info["client_email"])

    def test_from_service_account_info_args(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        creds = jwt.OnDemandCredentials.from_service_account_info(
            info, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS
        )

        self._assert_service_account_fields(creds, info, self.SUBJECT)
        assert creds._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_service_account_file(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        creds = jwt.OnDemandCredentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE
        )

        self._assert_service_account_fields(creds, info, info["client_email"])

    def test_from_service_account_file_args(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        creds = jwt.OnDemandCredentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE,
            subject=self.SUBJECT,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        self._assert_service_account_fields(creds, info, self.SUBJECT)
        assert creds._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_signing_credentials(self):
        from_signing = self.credentials.from_signing_credentials(self.credentials)
        from_info = jwt.OnDemandCredentials.from_service_account_info(
            SERVICE_ACCOUNT_INFO
        )

        assert isinstance(from_signing, jwt.OnDemandCredentials)
        assert from_signing._signer.key_id == from_info._signer.key_id
        for attr in ("_issuer", "_subject"):
            assert getattr(from_signing, attr) == getattr(from_info, attr)

    def test_default_state(self):
        creds = self.credentials
        # Credentials are *always* valid.
        assert creds.valid
        # Credentials *never* expire.
        assert not creds.expired

    def test_with_claims(self):
        new_claims = {"meep": "moop"}
        original = self.credentials
        updated = original.with_claims(additional_claims=new_claims)

        for attr in ("_signer", "_issuer", "_subject"):
            assert getattr(updated, attr) == getattr(original, attr)
        assert updated._additional_claims == new_claims

    def test_with_quota_project(self):
        quota_project_id = "project-foo"
        original = self.credentials
        updated = original.with_quota_project(quota_project_id)

        for attr in ("_signer", "_issuer", "_subject", "_additional_claims"):
            assert getattr(updated, attr) == getattr(original, attr)
        assert updated._quota_project_id == quota_project_id

    def test_sign_bytes(self):
        message = b"123"
        signature = self.credentials.sign_bytes(message)
        assert crypt.verify_signature(message, signature, PUBLIC_CERT_BYTES)

    def test_signer(self):
        assert isinstance(self.credentials.signer, crypt.RSASigner)

    def test_signer_email(self):
        expected_email = SERVICE_ACCOUNT_INFO["client_email"]
        assert self.credentials.signer_email == expected_email

    def _verify_token(self, token):
        """Decode ``token`` with the public cert, check the issuer, return claims."""
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        return payload

    def test_refresh(self):
        # Explicit refresh is not supported and raises RefreshError.
        with pytest.raises(exceptions.RefreshError):
            self.credentials.refresh(None)

    def test_before_request(self):
        headers = {}
        creds = self.credentials

        creds.before_request(None, "GET", "http://example.com?a=1#3", headers)
        _, token = headers["authorization"].split(" ")
        payload = self._verify_token(token)

        # The audience is the request URL with its query and fragment removed.
        assert payload["aud"] == "http://example.com"

        # Making another request should re-use the same token.
        creds.before_request(None, "GET", "http://example.com?b=2", headers)
        _, new_token = headers["authorization"].split(" ")

        assert new_token == token

    def test_expired_token(self):
        # Seed the cache with an entry whose second element (presumably the
        # expiry) is the earliest representable datetime.
        self.credentials._cache["audience"] = (
            mock.sentinel.token,
            datetime.datetime.min,
        )

        token = self.credentials._get_jwt_for_audience("audience")

        # The stale cached token must not be handed back.
        assert token != mock.sentinel.token
647