import json
import urllib.request
from functools import lru_cache
from typing import Any, List, Optional

from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError


class PyJWKClient:
    """Fetch a JSON Web Key Set (JWKS) from a URI and look up signing keys in it."""

    def __init__(
        self,
        uri: str,
        cache_keys: bool = True,
        max_cached_keys: int = 16,
        timeout: Optional[float] = None,
    ):
        """Create a client for the JWKS document served at ``uri``.

        :param uri: URL of the JWKS endpoint; opened with ``urllib.request.urlopen``.
        :param cache_keys: when true, results of :meth:`get_signing_key` are
            memoized per ``kid`` on this instance.
        :param max_cached_keys: maximum number of ``kid`` lookups kept in the cache.
        :param timeout: optional timeout in seconds for fetching the JWKS. When
            ``None`` (the default) no timeout is passed to ``urlopen``, so its
            own default applies.
        """
        self.uri = uri
        self.timeout = timeout
        if cache_keys:
            # Cache signing keys by wrapping the bound method and storing it on the
            # instance, so the cache is per client. Lookups that raise (unknown kid)
            # are not cached by lru_cache and will refetch on the next call.
            # Ignore mypy (https://github.com/python/mypy/issues/2427)
            self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key)  # type: ignore

    def fetch_data(self) -> Any:
        """Download the JWKS document from ``self.uri`` and return the parsed JSON."""
        # Only forward a timeout when one was configured, so the default path calls
        # urlopen exactly as before.
        kwargs = {} if self.timeout is None else {"timeout": self.timeout}
        with urllib.request.urlopen(self.uri, **kwargs) as response:
            return json.load(response)

    def get_jwk_set(self) -> PyJWKSet:
        """Fetch the JWKS and build a :class:`PyJWKSet` from it."""
        data = self.fetch_data()
        return PyJWKSet.from_dict(data)

    def get_signing_keys(self) -> List[PyJWK]:
        """Return the keys usable for signature verification.

        A key qualifies when its ``use`` is ``"sig"`` or absent, and it has a
        non-empty key id.

        :raises PyJWKClientError: if no key in the set qualifies.
        """
        jwk_set = self.get_jwk_set()
        signing_keys = [
            jwk_set_key
            for jwk_set_key in jwk_set.keys
            if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id
        ]

        if not signing_keys:
            raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")

        return signing_keys

    def get_signing_key(self, kid: str) -> PyJWK:
        """Return the signing key whose key id equals ``kid``.

        :raises PyJWKClientError: if the JWKS has no signing keys, or none matches ``kid``.
        """
        signing_key = next(
            (key for key in self.get_signing_keys() if key.key_id == kid), None
        )

        # Compare against None explicitly rather than relying on the key's truthiness.
        if signing_key is None:
            raise PyJWKClientError(
                f'Unable to find a signing key that matches: "{kid}"'
            )

        return signing_key

    def get_signing_key_from_jwt(self, token: str) -> PyJWK:
        """Return the signing key referenced by the ``kid`` header of ``token``.

        The token is decoded with signature verification disabled purely to read
        its header; it is NOT validated here. If the header has no ``kid``, ``None``
        is looked up. It never matches, because :meth:`get_signing_keys` keeps only
        keys with a key id, so :class:`PyJWKClientError` is raised.
        """
        unverified = decode_token(token, options={"verify_signature": False})
        header = unverified["header"]
        return self.get_signing_key(header.get("kid"))