#------------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation.
# All rights reserved.
#
# This code is licensed under the MIT License.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
#------------------------------------------------------------------------------

import time
import datetime
import uuid
import base64
import binascii
import re

import jwt

from .constants import Jwt
from .log import Logger
from .adal_error import AdalError


def _get_date_now():
    """Return the current local wall-clock time as a naive datetime.

    Isolated in its own function, presumably so the clock can be patched
    (e.g. by tests) — verify against callers.
    """
    return datetime.datetime.now()


def _get_new_jwt_id():
    """Return a fresh random UUID4, as a string, for the JWT ID claim."""
    return str(uuid.uuid4())


def _create_x5t_value(thumbprint):
    """Encode a hex thumbprint as URL-safe base64 for the ``x5t`` header.

    :param thumbprint: hex string (expected already canonicalized by
        ``SelfSignedJwt._reduce_thumbprint``: lowercase, no separators).
    :returns: URL-safe base64 text of the raw thumbprint bytes, including any
        ``=`` padding emitted by ``base64.urlsafe_b64encode``.
    """
    raw_bytes = binascii.a2b_hex(thumbprint)
    return base64.urlsafe_b64encode(raw_bytes).decode()


def _sign_jwt(header, payload, certificate):
    """Sign *payload* with *certificate* (a PEM private key) using RS256.

    :returns: the compact-serialized JWT as a string.
    :raises AdalError: if encoding/signing raises anything, or if the
        resulting token has an empty or missing signature segment.
    """
    try:
        encoded_jwt = _encode_jwt(payload, certificate, header)
    except Exception as exp:
        # Every encoding failure is reported as an invalid certificate; the
        # original exception is handed to AdalError as its second argument.
        raise AdalError("Error:Invalid Certificate: Expected Start of Certificate to be '-----BEGIN RSA PRIVATE KEY-----'", exp)
    _raise_on_invalid_jwt_signature(encoded_jwt)
    return encoded_jwt


def _encode_jwt(payload, certificate, header):
    """Call ``jwt.encode`` with RS256 and normalize the result to ``str``."""
    encoded = jwt.encode(payload, certificate, algorithm='RS256', headers=header)
    try:
        return encoded.decode()  # PyJWT 1.x returns bytes; historically we convert it to string
    except AttributeError:
        return encoded  # PyJWT 2 will return string


def _raise_on_invalid_jwt_signature(encoded_jwt):
    """Raise AdalError unless *encoded_jwt* has a non-empty third segment."""
    segments = encoded_jwt.split('.')
    if len(segments) < 3 or not segments[2]:
        raise AdalError('Failed to sign JWT. This is most likely due to an invalid certificate.')


def _extract_certs(public_cert_content):
    """Parse raw public-certificate file contents into a list of strings.

    Each entry is the body between a BEGIN/END CERTIFICATE pair, stripped of
    surrounding whitespace. If no such pair is present, the whole (stripped)
    input is returned as a single entry.

    Usage: headers = {"x5c": _extract_certs(open("my_cert.pem").read())}

    :raises ValueError: if no certificate markers are found and the input
        contains "PRIVATE KEY".
    """
    public_certificates = re.findall(
        r'-----BEGIN CERTIFICATE-----(?P<cert_value>[^-]+)-----END CERTIFICATE-----',
        public_cert_content, re.I)
    if public_certificates:
        return [cert.strip() for cert in public_certificates]
    # The public cert tags are not found in the input,
    # let's make best effort to exclude a private key pem file.
    if "PRIVATE KEY" in public_cert_content:
        raise ValueError(
            "We expect your public key but detect a private key instead")
    return [public_cert_content.strip()]


class SelfSignedJwt(object):
    """Build RS256-signed JWTs from a certificate and its thumbprint.

    The payload uses the client id as issuer and subject and the authority's
    token endpoint as audience.
    """

    # Accepted thumbprint lengths in hex characters (128-bit and 160-bit
    # digests). Floor division keeps these ints on Python 3, where ``/``
    # would produce 32.0 / 40.0.
    NumCharIn128BitHexString = 128 // 8 * 2
    numCharIn160BitHexString = 160 // 8 * 2
    # Lowercase hex only. ``[0-9a-f]`` instead of ``\d`` so non-ASCII Unicode
    # digits are rejected, and ``\Z`` instead of ``$`` so a trailing newline
    # cannot slip through validation and fail later in ``a2b_hex``.
    ThumbprintRegEx = r"^[0-9a-f]*\Z"

    def __init__(self, call_context, authority, client_id):
        """
        :param call_context: mapping providing at least ``'log_context'``.
        :param authority: object exposing ``token_endpoint``, used as the
            JWT audience.
        :param client_id: used as both JWT issuer and subject.
        """
        self._log = Logger('SelfSignedJwt', call_context['log_context'])
        self._call_context = call_context

        # NOTE: attribute name is misspelled; kept unchanged in case other
        # code references it.
        self._authortiy = authority
        self._token_endpoint = authority.token_endpoint
        self._client_id = client_id

    def _create_header(self, thumbprint, public_certificate):
        """Return the JOSE header dict (typ, alg, x5t and optionally x5c)."""
        x5t = _create_x5t_value(thumbprint)
        header = {'typ': 'JWT', 'alg': 'RS256', 'x5t': x5t}
        if public_certificate:
            header['x5c'] = _extract_certs(public_certificate)
        self._log.debug("Creating self signed JWT header. x5t: %(x5t)s, x5c: %(x5c)s",
                        {"x5t": x5t, "x5c": public_certificate})

        return header

    def _create_payload(self):
        """Return the claims dict: aud, iss, sub, nbf, exp and jti.

        ``nbf`` and ``exp`` are integer epoch seconds; the lifetime is
        ``Jwt.SELF_SIGNED_JWT_LIFETIME`` minutes.
        """
        now = _get_date_now()
        lifetime = datetime.timedelta(minutes=Jwt.SELF_SIGNED_JWT_LIFETIME)
        expires = now + lifetime

        self._log.debug(
            'Creating self signed JWT payload. Expires: %(expires)s NotBefore: %(nbf)s',
            {"expires": expires, "nbf": now})

        not_before = int(time.mktime(now.timetuple()))
        jwt_payload = {
            Jwt.AUDIENCE: self._token_endpoint,
            Jwt.ISSUER: self._client_id,
            Jwt.SUBJECT: self._client_id,
            Jwt.NOT_BEFORE: not_before,
            # Derive exp from the epoch nbf instead of converting the naive
            # local `expires` separately: wall-clock addition passed through
            # mktime can skew the lifetime by an hour across a DST change.
            Jwt.EXPIRES_ON: not_before + int(lifetime.total_seconds()),
            Jwt.JWT_ID: _get_new_jwt_id(),
        }

        return jwt_payload

    def _raise_on_invalid_thumbprint(self, thumbprint):
        """Raise AdalError unless *thumbprint* is 32 or 40 lowercase hex chars."""
        thumbprint_sizes = (self.NumCharIn128BitHexString, self.numCharIn160BitHexString)
        size_ok = len(thumbprint) in thumbprint_sizes
        if not size_ok or not re.search(self.ThumbprintRegEx, thumbprint):
            raise AdalError("The thumbprint does not match a known format")

    def _reduce_thumbprint(self, thumbprint):
        """Lowercase, drop spaces and colons, validate, and return the result."""
        canonical = thumbprint.lower().replace(' ', '').replace(':', '')
        self._raise_on_invalid_thumbprint(canonical)
        return canonical

    def create(self, certificate, thumbprint, public_certificate):
        """Build and sign a JWT.

        :param certificate: PEM private key used to sign (RS256).
        :param thumbprint: hex thumbprint; spaces/colons and case are ignored.
        :param public_certificate: optional public certificate contents; when
            truthy, its certificates are placed in the ``x5c`` header.
        :returns: the signed JWT string.
        :raises AdalError: on an invalid thumbprint or signing failure.
        """
        thumbprint = self._reduce_thumbprint(thumbprint)

        header = self._create_header(thumbprint, public_certificate)
        payload = self._create_payload()
        return _sign_jwt(header, payload, certificate)