# Copyright 2014 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for :mod:`google.auth.jwt` encoding, decoding and credentials."""

import base64
import datetime
import json
import os

import mock
import pytest

from google.auth import _helpers
from google.auth import crypt
from google.auth import exceptions
from google.auth import jwt


DATA_DIR = os.path.join(os.path.dirname(__file__), "data")

with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
    PRIVATE_KEY_BYTES = fh.read()

with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh:
    PUBLIC_CERT_BYTES = fh.read()

# A certificate that does NOT match PRIVATE_KEY_BYTES; used for negative tests.
with open(os.path.join(DATA_DIR, "other_cert.pem"), "rb") as fh:
    OTHER_CERT_BYTES = fh.read()

with open(os.path.join(DATA_DIR, "es256_privatekey.pem"), "rb") as fh:
    EC_PRIVATE_KEY_BYTES = fh.read()

with open(os.path.join(DATA_DIR, "es256_public_cert.pem"), "rb") as fh:
    EC_PUBLIC_CERT_BYTES = fh.read()

SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")

with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
    SERVICE_ACCOUNT_INFO = json.load(fh)


@pytest.fixture
def signer():
    """RS256 signer built from the test private key, with key id ``"1"``."""
    return crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1")


def test_encode_basic(signer):
    test_payload = {"test": "value"}
    encoded = jwt.encode(signer, test_payload)
    header, payload, _, _ = jwt._unverified_decode(encoded)
    assert payload == test_payload
    assert header == {"typ": "JWT", "alg": "RS256", "kid": signer.key_id}


def test_encode_extra_headers(signer):
    encoded = jwt.encode(signer, {}, header={"extra": "value"})
    header = jwt.decode_header(encoded)
    assert header == {
        "typ": "JWT",
        "alg": "RS256",
        "kid": signer.key_id,
        "extra": "value",
    }


def test_encode_custom_alg_in_headers(signer):
    """An explicit ``alg`` in the extra header overrides the signer's algorithm."""
    encoded = jwt.encode(signer, {}, header={"alg": "foo"})
    header = jwt.decode_header(encoded)
    assert header == {"typ": "JWT", "alg": "foo", "kid": signer.key_id}


@pytest.fixture
def es256_signer():
    """ES256 signer built from the test EC private key, with key id ``"1"``."""
    return crypt.ES256Signer.from_string(EC_PRIVATE_KEY_BYTES, "1")


def test_encode_basic_es256(es256_signer):
    test_payload = {"test": "value"}
    encoded = jwt.encode(es256_signer, test_payload)
    header, payload, _, _ = jwt._unverified_decode(encoded)
    assert payload == test_payload
    assert header == {"typ": "JWT", "alg": "ES256", "kid": es256_signer.key_id}


@pytest.fixture
def token_factory(signer, es256_signer):
    """Return a callable that builds signed JWTs for the decode tests.

    The returned ``factory(claims=None, key_id=None, use_es256_signer=False)``
    accepts:

    * ``claims``: extra claims merged over the default payload (``aud``,
      ``iat``, ``exp`` five minutes ahead, ``user`` and ``metadata``).
    * ``key_id``: key id written to the header. ``None`` uses the signer's
      own key id; ``False`` strips the key id from the signer in use so the
      token is produced without one.
    * ``use_es256_signer``: sign with the ES256 signer instead of RS256.
    """

    def factory(claims=None, key_id=None, use_es256_signer=False):
        now = _helpers.datetime_to_secs(_helpers.utcnow())
        payload = {
            "aud": "audience@example.com",
            "iat": now,
            "exp": now + 300,
            "user": "billy bob",
            "metadata": {"meta": "data"},
        }
        payload.update(claims or {})

        # Choose the signer up front so that ``key_id=False`` clears the key
        # id on the signer that actually signs the token (including ES256).
        chosen_signer = es256_signer if use_es256_signer else signer

        # False is specified to remove the signer's key id for testing
        # headers without key ids.
        if key_id is False:
            chosen_signer._key_id = None
            key_id = None

        return jwt.encode(chosen_signer, payload, key_id=key_id)

    return factory


def test_decode_valid(token_factory):
    payload = jwt.decode(token_factory(), certs=PUBLIC_CERT_BYTES)
    assert payload["aud"] == "audience@example.com"
    assert payload["user"] == "billy bob"
    assert payload["metadata"]["meta"] == "data"


def test_decode_valid_es256(token_factory):
    payload = jwt.decode(
        token_factory(use_es256_signer=True), certs=EC_PUBLIC_CERT_BYTES
    )
    assert payload["aud"] == "audience@example.com"
    assert payload["user"] == "billy bob"
    assert payload["metadata"]["meta"] == "data"


def test_decode_valid_with_audience(token_factory):
    payload = jwt.decode(
        token_factory(), certs=PUBLIC_CERT_BYTES, audience="audience@example.com"
    )
    assert payload["aud"] == "audience@example.com"
    assert payload["user"] == "billy bob"
    assert payload["metadata"]["meta"] == "data"


def test_decode_valid_with_audience_list(token_factory):
    """The token is accepted when its ``aud`` is any member of the list."""
    payload = jwt.decode(
        token_factory(),
        certs=PUBLIC_CERT_BYTES,
        audience=["audience@example.com", "another_audience@example.com"],
    )
    assert payload["aud"] == "audience@example.com"
    assert payload["user"] == "billy bob"
    assert payload["metadata"]["meta"] == "data"


def test_decode_valid_unverified(token_factory):
    """With ``verify=False`` a non-matching certificate is not an error."""
    payload = jwt.decode(token_factory(), certs=OTHER_CERT_BYTES, verify=False)
    assert payload["aud"] == "audience@example.com"
    assert payload["user"] == "billy bob"
    assert payload["metadata"]["meta"] == "data"


def test_decode_bad_token_wrong_number_of_segments():
    with pytest.raises(ValueError) as excinfo:
        jwt.decode("1.2", PUBLIC_CERT_BYTES)
    assert excinfo.match(r"Wrong number of segments")


def test_decode_bad_token_not_base64():
    # Accept either exception type / message wording; presumably these
    # differ between Python versions' base64 decoders.
    with pytest.raises((ValueError, TypeError)) as excinfo:
        jwt.decode("1.2.3", PUBLIC_CERT_BYTES)
    assert excinfo.match(r"Incorrect padding|more than a multiple of 4")


def test_decode_bad_token_not_json():
    token = b".".join([base64.urlsafe_b64encode(b"123!")] * 3)
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES)
    assert excinfo.match(r"Can\'t parse segment")


def test_decode_bad_token_no_iat_or_exp(signer):
    token = jwt.encode(signer, {"test": "value"})
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES)
    assert excinfo.match(r"Token does not contain required claim")


def test_decode_bad_token_too_early(token_factory):
    """An ``iat`` one hour in the future is outside a 59 second clock skew."""
    token = token_factory(
        claims={
            "iat": _helpers.datetime_to_secs(
                _helpers.utcnow() + datetime.timedelta(hours=1)
            )
        }
    )
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES, clock_skew_in_seconds=59)
    assert excinfo.match(r"Token used too early")


def test_decode_bad_token_expired(token_factory):
    """An ``exp`` one hour in the past is outside a 59 second clock skew."""
    token = token_factory(
        claims={
            "exp": _helpers.datetime_to_secs(
                _helpers.utcnow() - datetime.timedelta(hours=1)
            )
        }
    )
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES, clock_skew_in_seconds=59)
    assert excinfo.match(r"Token expired")


def test_decode_success_with_no_clock_skew(token_factory):
    token = token_factory(
        claims={
            "exp": _helpers.datetime_to_secs(
                _helpers.utcnow() + datetime.timedelta(seconds=1)
            ),
            "iat": _helpers.datetime_to_secs(
                _helpers.utcnow() - datetime.timedelta(seconds=1)
            ),
        }
    )

    jwt.decode(token, PUBLIC_CERT_BYTES)


def test_decode_success_with_custom_clock_skew(token_factory):
    token = token_factory(
        claims={
            "exp": _helpers.datetime_to_secs(
                _helpers.utcnow() + datetime.timedelta(seconds=2)
            ),
            "iat": _helpers.datetime_to_secs(
                _helpers.utcnow() - datetime.timedelta(seconds=2)
            ),
        }
    )

    jwt.decode(token, PUBLIC_CERT_BYTES, clock_skew_in_seconds=1)


def test_decode_bad_token_wrong_audience(token_factory):
    token = token_factory()
    audience = "audience2@example.com"
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES, audience=audience)
    assert excinfo.match(r"Token has wrong audience")


def test_decode_bad_token_wrong_audience_list(token_factory):
    token = token_factory()
    audience = ["audience2@example.com", "audience3@example.com"]
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token, PUBLIC_CERT_BYTES, audience=audience)
    assert excinfo.match(r"Token has wrong audience")


def test_decode_wrong_cert(token_factory):
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token_factory(), OTHER_CERT_BYTES)
    assert excinfo.match(r"Could not verify token signature")


def test_decode_multicert_bad_cert(token_factory):
    """The token's ``kid`` is ``"1"``, which maps to the wrong certificate.

    Verification fails even though the certificate under ``"2"`` would match.
    """
    certs = {"1": OTHER_CERT_BYTES, "2": PUBLIC_CERT_BYTES}
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token_factory(), certs)
    assert excinfo.match(r"Could not verify token signature")


def test_decode_no_cert(token_factory):
    certs = {"2": PUBLIC_CERT_BYTES}
    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token_factory(), certs)
    assert excinfo.match(r"Certificate for key id 1 not found")


def test_decode_no_key_id(token_factory):
    """A token without ``kid`` still verifies against an unrelated key id.

    Presumably decode falls back to trying the available certificates.
    """
    token = token_factory(key_id=False)
    certs = {"2": PUBLIC_CERT_BYTES}
    payload = jwt.decode(token, certs)
    assert payload["user"] == "billy bob"


def _make_token_with_header(header):
    """Build a three-segment token from ``header`` with an empty JSON payload.

    Each segment is standard (not URL-safe) base64 and the signature segment
    is the literal ``"sig"`` rather than a real signature. The callers only
    check the error raised while resolving the header's algorithm.
    """
    segments = [json.dumps(header), u"{}", u"sig"]
    return b".".join(base64.b64encode(seg.encode("utf-8")) for seg in segments)


def test_decode_unknown_alg():
    token = _make_token_with_header({u"kid": u"1", u"alg": u"fakealg"})

    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token)
    assert excinfo.match(r"fakealg")


def test_decode_missing_crytography_alg(monkeypatch):
    # Simulate the ES256 verifier being unavailable (e.g. the optional
    # cryptography dependency missing) by removing it from the registry.
    monkeypatch.delitem(jwt._ALGORITHM_TO_VERIFIER_CLASS, "ES256")
    token = _make_token_with_header({u"kid": u"1", u"alg": u"ES256"})

    with pytest.raises(ValueError) as excinfo:
        jwt.decode(token)
    assert excinfo.match(r"cryptography")


def test_roundtrip_explicit_key_id(token_factory):
    token = token_factory(key_id="3")
    certs = {"2": OTHER_CERT_BYTES, "3": PUBLIC_CERT_BYTES}
    payload = jwt.decode(token, certs)
    assert payload["user"] == "billy bob"


class TestCredentials(object):
    """Tests for :class:`jwt.Credentials` with a fixed audience."""

    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    SUBJECT = "subject"
    AUDIENCE = "audience"
    ADDITIONAL_CLAIMS = {"meta": "data"}
    # Replaced with a fresh instance before every test by credentials_fixture.
    credentials = None

    @pytest.fixture(autouse=True)
    def credentials_fixture(self, signer):
        """Build credentials that use the service-account email twice.

        The email is passed as both the issuer and the subject, together
        with the RS256 signer and ``AUDIENCE``.
        """
        self.credentials = jwt.Credentials(
            signer,
            self.SERVICE_ACCOUNT_EMAIL,
            self.SERVICE_ACCOUNT_EMAIL,
            self.AUDIENCE,
        )

    def test_from_service_account_info(self):
        with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
            info = json.load(fh)

        credentials = jwt.Credentials.from_service_account_info(
            info, audience=self.AUDIENCE
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == info["client_email"]
        assert credentials._audience == self.AUDIENCE

    def test_from_service_account_info_args(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt.Credentials.from_service_account_info(
            info,
            subject=self.SUBJECT,
            audience=self.AUDIENCE,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == self.SUBJECT
        assert credentials._audience == self.AUDIENCE
        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_service_account_file(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt.Credentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE, audience=self.AUDIENCE
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == info["client_email"]
        assert credentials._audience == self.AUDIENCE

    def test_from_service_account_file_args(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt.Credentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE,
            subject=self.SUBJECT,
            audience=self.AUDIENCE,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == self.SUBJECT
        assert credentials._audience == self.AUDIENCE
        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_signing_credentials(self):
        jwt_from_signing = self.credentials.from_signing_credentials(
            self.credentials, audience=mock.sentinel.new_audience
        )
        jwt_from_info = jwt.Credentials.from_service_account_info(
            SERVICE_ACCOUNT_INFO, audience=mock.sentinel.new_audience
        )

        assert isinstance(jwt_from_signing, jwt.Credentials)
        assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
        assert jwt_from_signing._issuer == jwt_from_info._issuer
        assert jwt_from_signing._subject == jwt_from_info._subject
        assert jwt_from_signing._audience == jwt_from_info._audience

    def test_default_state(self):
        assert not self.credentials.valid
        # Expiration hasn't been set yet
        assert not self.credentials.expired

    def test_with_claims(self):
        new_audience = "new_audience"
        new_credentials = self.credentials.with_claims(audience=new_audience)

        assert new_credentials._signer == self.credentials._signer
        assert new_credentials._issuer == self.credentials._issuer
        assert new_credentials._subject == self.credentials._subject
        assert new_credentials._audience == new_audience
        assert new_credentials._additional_claims == self.credentials._additional_claims
        assert new_credentials._quota_project_id == self.credentials._quota_project_id

    def test__make_jwt_without_audience(self):
        """With ``audience=None`` the minted payload has no ``aud`` claim."""
        cred = jwt.Credentials.from_service_account_info(
            SERVICE_ACCOUNT_INFO.copy(),
            subject=self.SUBJECT,
            audience=None,
            additional_claims={"scope": "foo bar"},
        )
        token, _ = cred._make_jwt()
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["scope"] == "foo bar"
        assert "aud" not in payload

    def test_with_quota_project(self):
        quota_project_id = "project-foo"

        new_credentials = self.credentials.with_quota_project(quota_project_id)
        assert new_credentials._signer == self.credentials._signer
        assert new_credentials._issuer == self.credentials._issuer
        assert new_credentials._subject == self.credentials._subject
        assert new_credentials._audience == self.credentials._audience
        assert new_credentials._additional_claims == self.credentials._additional_claims
        assert new_credentials._quota_project_id == quota_project_id

    def test_sign_bytes(self):
        to_sign = b"123"
        signature = self.credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)

    def test_signer(self):
        assert isinstance(self.credentials.signer, crypt.RSASigner)

    def test_signer_email(self):
        assert self.credentials.signer_email == SERVICE_ACCOUNT_INFO["client_email"]

    def _verify_token(self, token):
        """Decode ``token`` with the public cert, check ``iss``, return payload."""
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        return payload

    def test_refresh(self):
        self.credentials.refresh(None)
        assert self.credentials.valid
        assert not self.credentials.expired

    def test_expired(self):
        assert not self.credentials.expired

        self.credentials.refresh(None)
        assert not self.credentials.expired

        # Move "now" one day past the expiry to force the expired state.
        with mock.patch("google.auth._helpers.utcnow") as now:
            one_day = datetime.timedelta(days=1)
            now.return_value = self.credentials.expiry + one_day
            assert self.credentials.expired

    def test_before_request(self):
        headers = {}

        self.credentials.refresh(None)
        self.credentials.before_request(
            None, "GET", "http://example.com?a=1#3", headers
        )

        header_value = headers["authorization"]
        _, token = header_value.split(" ")

        # Since the audience is set, it should use the existing token.
        assert token.encode("utf-8") == self.credentials.token

        payload = self._verify_token(token)
        assert payload["aud"] == self.AUDIENCE

    def test_before_request_refreshes(self):
        assert not self.credentials.valid
        self.credentials.before_request(None, "GET", "http://example.com?a=1#3", {})
        assert self.credentials.valid


class TestOnDemandCredentials(object):
    """Tests for :class:`jwt.OnDemandCredentials`.

    These credentials report themselves as always valid and never expired.
    ``refresh`` raises, and tokens are minted per audience inside
    ``before_request``.
    """

    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    SUBJECT = "subject"
    ADDITIONAL_CLAIMS = {"meta": "data"}
    # Replaced with a fresh instance before every test by credentials_fixture.
    credentials = None

    @pytest.fixture(autouse=True)
    def credentials_fixture(self, signer):
        """Build on-demand credentials with a token cache of size 2."""
        self.credentials = jwt.OnDemandCredentials(
            signer,
            self.SERVICE_ACCOUNT_EMAIL,
            self.SERVICE_ACCOUNT_EMAIL,
            max_cache_size=2,
        )

    def test_from_service_account_info(self):
        with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
            info = json.load(fh)

        credentials = jwt.OnDemandCredentials.from_service_account_info(info)

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == info["client_email"]

    def test_from_service_account_info_args(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt.OnDemandCredentials.from_service_account_info(
            info, subject=self.SUBJECT, additional_claims=self.ADDITIONAL_CLAIMS
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == self.SUBJECT
        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_service_account_file(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt.OnDemandCredentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == info["client_email"]

    def test_from_service_account_file_args(self):
        info = SERVICE_ACCOUNT_INFO.copy()

        credentials = jwt.OnDemandCredentials.from_service_account_file(
            SERVICE_ACCOUNT_JSON_FILE,
            subject=self.SUBJECT,
            additional_claims=self.ADDITIONAL_CLAIMS,
        )

        assert credentials._signer.key_id == info["private_key_id"]
        assert credentials._issuer == info["client_email"]
        assert credentials._subject == self.SUBJECT
        assert credentials._additional_claims == self.ADDITIONAL_CLAIMS

    def test_from_signing_credentials(self):
        jwt_from_signing = self.credentials.from_signing_credentials(self.credentials)
        jwt_from_info = jwt.OnDemandCredentials.from_service_account_info(
            SERVICE_ACCOUNT_INFO
        )

        assert isinstance(jwt_from_signing, jwt.OnDemandCredentials)
        assert jwt_from_signing._signer.key_id == jwt_from_info._signer.key_id
        assert jwt_from_signing._issuer == jwt_from_info._issuer
        assert jwt_from_signing._subject == jwt_from_info._subject

    def test_default_state(self):
        # Credentials are *always* valid.
        assert self.credentials.valid
        # Credentials *never* expire.
        assert not self.credentials.expired

    def test_with_claims(self):
        new_claims = {"meep": "moop"}
        new_credentials = self.credentials.with_claims(additional_claims=new_claims)

        assert new_credentials._signer == self.credentials._signer
        assert new_credentials._issuer == self.credentials._issuer
        assert new_credentials._subject == self.credentials._subject
        assert new_credentials._additional_claims == new_claims

    def test_with_quota_project(self):
        quota_project_id = "project-foo"
        new_credentials = self.credentials.with_quota_project(quota_project_id)

        assert new_credentials._signer == self.credentials._signer
        assert new_credentials._issuer == self.credentials._issuer
        assert new_credentials._subject == self.credentials._subject
        assert new_credentials._additional_claims == self.credentials._additional_claims
        assert new_credentials._quota_project_id == quota_project_id

    def test_sign_bytes(self):
        to_sign = b"123"
        signature = self.credentials.sign_bytes(to_sign)
        assert crypt.verify_signature(to_sign, signature, PUBLIC_CERT_BYTES)

    def test_signer(self):
        assert isinstance(self.credentials.signer, crypt.RSASigner)

    def test_signer_email(self):
        assert self.credentials.signer_email == SERVICE_ACCOUNT_INFO["client_email"]

    def _verify_token(self, token):
        """Decode ``token`` with the public cert, check ``iss``, return payload."""
        payload = jwt.decode(token, PUBLIC_CERT_BYTES)
        assert payload["iss"] == self.SERVICE_ACCOUNT_EMAIL
        return payload

    def test_refresh(self):
        with pytest.raises(exceptions.RefreshError):
            self.credentials.refresh(None)

    def test_before_request(self):
        headers = {}

        self.credentials.before_request(
            None, "GET", "http://example.com?a=1#3", headers
        )

        _, token = headers["authorization"].split(" ")
        payload = self._verify_token(token)

        # The query string and fragment are dropped when forming the audience.
        assert payload["aud"] == "http://example.com"

        # Making another request should re-use the same token.
        self.credentials.before_request(None, "GET", "http://example.com?b=2", headers)

        _, new_token = headers["authorization"].split(" ")

        assert new_token == token

    def test_expired_token(self):
        # Seed the cache with an entry whose expiry (datetime.min) is long
        # past, so _get_jwt_for_audience must mint a fresh token.
        self.credentials._cache["audience"] = (
            mock.sentinel.token,
            datetime.datetime.min,
        )

        token = self.credentials._get_jwt_for_audience("audience")

        assert token != mock.sentinel.token